CV Week: Итоговое задание¶

На лекции и семинаре мы разбирали как дистиллировать многошаговую диффузионную модель в малошагового студента, и тем самым будет работать на порядок быстрее учителя.

Один из подходов, который мы разбирали Consistency Distillation. В этом задании, мы закрепим материал, который был на лекции и семинаре и реализуем этот фреймворк, затрагивая различные нюансы.

В этом задании мы будем дистиллировать модель Stable Diffusion 1.5 (SD1.5) для генерации картинок по текстовому описанию.

Вам предстоит выполнить 8 небольших заданий, которые приведут нас к неплохой модели для генерации картинок за 4 шага, работая в органиченных условиях колаба.

In [1]:
# torch 2.4.1+cu124
# !pip install diffusers==0.30.3, peft==0.8.2 huggingface_hub==0.23.4

Теормин¶


Диффузионные модели¶

Задан прямой диффузионный процесс, который переводит чистые картинки в шум с помощью распределения $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$

Таким образом, мы можем получаться зашумленные картинки по следующей формуле: $\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$ (1)

$\alpha_t, \sigma_t$ задают процесс зашумления. Здесь мы будем иметь дело с variance preserving (VP) процессом, т. е., $\alpha^2_t = 1 - \sigma^2_t$.

Диффузионная модель (ДМ) пытается решить обратную задачу: из шума порождать новые картинки. Важно, что диффузионный процесс можно описать следующим обыкновенным дифференциальным уравнением (ОДУ):

$dx = \left[ f(\mathbf{x}, t) - \frac{1}{2} \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}) \right] dt$, (2)

где $f(\mathbf{x}, t)$ известен из заданного процесса зашумления, а $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$ (скор функцию) оцениваем с помощью нейросети: $s_\theta(\mathbf{x}_t, t) \approx \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$. Таким образом, имея оценку на $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x})$, мы можем решить это ОДУ, стартуя со случайного шума, и получить картинку.

SD1.5 использует $\epsilon$-параметризацию, т.е., UNet пытается предсказать шум, который мы добавили на картинку по формуле (1). Оценку скор функции можно получить, пользуясь результатом, вытекающим из формулы Твидди: $s_\theta(\mathbf{x}_t, t) = - \frac{\epsilon_\theta(\mathbf{x}_t, t)} { \sigma_t}$

Чтобы решить ОДУ (2), нам нужно воспользоваться каким-то численным методом (солвером). В этом задании мы будем работать с не самым эффектным, но самым популярным солвером: DDIM, который является адаптированным методом Эйлера под диффузионный ОДУ.

Для VP процесса переход с помощью DDIM с шага $t$ на $s$ можно сделать следующим образом:

$ x_s = DDIM(\mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $

Этот переход можно интерпретировать так: получаем оценку на чистую картинку $\mathbf{x}_0$ на шаге $t$, используя $\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t}$, а потом снова зашумляем эту оценку на шаг $s$ по формуле (1), но только используем не случайный шум, а шум предсказанный моделью $\epsilon_\theta$.

Используя DDIM для SD1.5, можем получать хорошие картинки за 50 шагов.

SD1.5 - латентная ДМ, т.е. модель работает не в пиксельном пространстве, а в латентном пространстве VAE. Таким образом SD1.5 состоит из следующих компонент:

  1. VAE - переводит $3{\times}512{\times}512$ картинки в латенты $4{\times}64{\times}64$ и может декодировать их обратно в картинки.
  2. Текстовый энкодер - извлекает текстовые признаки из промпта. Эти признаки будут подаваться в диффузионную модель, чтобы дать модели информацию, что именно хотим сгенерировать
  3. Диффузионная модель - UNet, работающий на "латентных картинках" $4{\times}64{\times}64$.

Консистенси модели¶

Общая идея¶

Главная цель дистилляции диффузии - уменьшить количество шагов ДМ, при этом сохранив высокое качество картинок.

Консистенси модели (Consistency Models | CM) - класс моделей, где мы хотим выучить "консистенси функцию" $f_\theta(\mathbf{x}_t)$ - с любой точки $\mathbf{x}_{t}$ траектории диффузионного ОДУ (2) сразу предсказывать $\mathbf{x}_{0}$ (чистые данные) за один шаг. Если мы идеально выучим консистенси функцию, то сможем шагать из чистого шума сразу в картинку, что супер эффективно в отличии от генерации ДМ.

Отметим, что консистенси модель можно учить как независимую генеративную модель, без предобученной ДМ, и в задании 3 вам предстоит подумать, как это можно сделать.


No description has been provided for this image

Консистенси дистилляция (Consistency Distillation | CD) - подход, когда для обучения CM, мы используем предобученную ДМ. ДМ нам дает качественную инициализацию модели и уже обученную скор функцию, что сильно упрощает сходимость консистенси моделей.

Обучение CM¶

No description has been provided for this image

Главная принцип обучения консистенси моделей заключается в попытке удовлетворить self-consistency св-ву: выход CM на двух соседних точках траектории $\mathbf{x}_{t}$ и $\mathbf{x}_{t-1}$ должен совпадать по какой-то мере близости, например L2 расстояние: $\lVert f_\theta(\mathbf{x}_{t-1}) - f_\theta(\mathbf{x}_{t}) \rVert^2_2$.

Заметим, что self-consistency св-во удовлетворить очень просто без какого-либо обучения, взяв, например $f_\theta(\mathbf{x}_{t}) \equiv 0$.

Поэтому, чтобы избежать вырожденных решений, нам необходимо выставить граничное условие (boundary condition), которое будет требовать, чтобы в самой левой точке траектории около 0, модель предсказывала картинку, которую получает на вход: $f_\theta(\mathbf{x}_{\epsilon}) = \mathbf{x}_{\epsilon}$.

Практическое замечание: Для обеих точек траектории мы применяем одну и ту же модель $f_\theta(\cdot)$. Но выход модели на шаге ${t-1}$ является "таргетом" для выхода модели на шаге $t$ и поэтому выполнение модели для шага $t-1$ выполняется в torch.no_grad режиме.

Как получаться две соседние точки на траектории ОДУ?

Берем случайную картинку $\mathbf{x}_0$ из датасета.

Точку $\mathbf{x}_t$ получаем с помощью прямого процесса зашумления: $\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$

Чтобы получить соседнюю точку $\mathbf{x}_{t-1}$, нам нужно сделать шаг по траектории ОДУ, используя, например, DDIM солвер.

В консистенси дистилляции, мы делаем шаг предобученной ДМ: $\mathbf{x}_{t-1} = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, t-1)$

Важно: на практике мы можем брать не соседние шаги $t$ и $t-1$, а с некоторым интервалом, например 20 шагов. Размер интервала влияет на bias/variance trade-off в консистенси обучении: больше интервал между шагами - больше смещение, но меньше дисперсия, и наоборот. Для простоты в этом задании мы зафиксируем интервал - 20 шагов, но во многих работах размер интервала динамически меняют по ходу обучения.

In [2]:
from tqdm.auto import tqdm

import csv
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, LCMScheduler, UNet2DConditionModel, DDIMScheduler

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict

%matplotlib inline
import matplotlib.pyplot as plt
2024-12-09 22:53:22.236387: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-09 22:53:22.236469: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-09 22:53:22.302624: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-09 22:53:22.442945: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-09 22:53:23.908561: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
In [3]:
# ---------------------
# Visualization utils
# ---------------------


def visualize_images(images):
    assert len(images) == 4
    plt.figure(figsize=(12, 3))
    for i, image in enumerate(images):
        plt.subplot(1, 4, i + 1)
        plt.imshow(image)
        plt.axis("off")

    plt.subplots_adjust(wspace=-0.01, hspace=-0.01)


# --------------
# Tensor utils
# --------------


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


# ---------------
# Dataset utils
# ---------------


class COCODataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, subset_name="train2014_5k", transform=None, max_cnt=None):
        """
        Arguments:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.extensions = (
            ".jpg",
            ".jpeg",
            ".png",
            ".ppm",
            ".bmp",
            ".pgm",
            ".tif",
            ".tiff",
            ".webp",
        )
        sample_dir = os.path.join(root_dir, subset_name)

        # Collect sample paths
        self.samples = sorted(
            [
                os.path.join(sample_dir, fname)
                for fname in os.listdir(sample_dir)
                if fname[-4:] in self.extensions
            ],
            key=lambda x: x.split("/")[-1].split(".")[0],
        )
        self.samples = (
            self.samples if max_cnt is None else self.samples[:max_cnt]
        )  # restrict num samples

        # Collect captions
        self.captions = {}
        with open(os.path.join(root_dir, f"{subset_name}.csv"), newline="\n") as csvfile:
            spamreader = csv.reader(csvfile, delimiter=",")
            for i, row in enumerate(spamreader):
                if i == 0:
                    continue
                self.captions[row[1]] = row[2]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample_path = self.samples[idx]
        sample = Image.open(sample_path).convert("RGB")

        if self.transform:
            sample = self.transform(sample)

        return {
            "image": sample,
            "text": self.captions[os.path.basename(sample_path)],
            "idxs": idx,
        }

Модель учителя (SD1.5)¶

Задание №1¶

Давайте для начала загрузим модель StableDiffusion 1.5 и сгенерируем ей картинки за 50 шагов.

Важно: для экономии памяти, загружаем все компоненты модели в FP16. Не забываем положить модель на GPU.

In [4]:
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    variant="fp16",
)

pipe = pipe.to("cuda")

# Проверяем, что все компоненты модели в FP16 и на cuda
assert pipe.unet.dtype == torch.float16 and pipe.unet.device.type == "cuda"
assert pipe.vae.dtype == torch.float16 and pipe.vae.device.type == "cuda"
assert pipe.text_encoder.dtype == torch.float16 and pipe.text_encoder.device.type == "cuda"

# Заменяем дефолтный сэмплер на DDIM
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()

# Отдельно извлечем модель учителя, которую потом будем дистиллировать
teacher_unet = pipe.unet
Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Теперь сгенерируем картинки за 50 шагов. Вам нужно написать вызов pipe и передать в него промпт, число шагов генерации, генератор случайных чисел, гайденс скейл и указать, чтобы сгенерировалось 4 картинки на промпт.

In [5]:
prompt = "A sad puppy with large eyes"
guidance_scale = 7.5
generator = torch.Generator("cuda").manual_seed(1)

# generate 4 images
images = pipe(
    prompt,
    guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=4,
    num_inference_steps=50,
).images  # type: ignore

visualize_images(images)
  0%|          | 0/50 [00:00<?, ?it/s]
No description has been provided for this image

Давайте посмотрим, что выдаст модель за 4 шага. Все то же самое, что и выше, просто поменяем число шагов.

In [6]:
generator = torch.Generator("cuda").manual_seed(1)

images = pipe(
    prompt,
    # guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=4,
    num_inference_steps=4,
).images  # type: ignore

visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

На 4 шагах картинки получаются размазанными. Давайте постараемся починить их.

Создаем датасет¶

Чтобы ДЗ было легко выполнимым на colab, мы будем учить консистенси модели на небольшой обучающей выборке из 5000 пар текст-картинка из COCO датасета. Интересное свойство консистенси моделей - они могут сходиться до адекватного качества за несколько сотен шагов. Качество все еще будет не идеальным, но фазовый переход уже должен быть заметен.

Данные можно загрузить с помощью команд в ячейке ниже. В локальной текущей директории ./ должны появиться:

  • Папка train2014_5k с 5000 картинками
  • Файл train2014_5k.csv с 5000 промптами

Данные парсятся корректным образом в уже реализованном классе COCODataset.

In [7]:
# !wget https://storage.yandexcloud.net/yandex-research/train2014_5k.tar.gz
# !tar -xzf train2014_5k.tar.gz

Замечание: для более быстрого дебаггинга можете взять, например, 2500 картинок и прогнать на всей выборке только в самом конце. 2500 картинок должно быть достаточно для понимания корректно ли реализованы функции. Совсем для первичного дебаггинга можно взять еще меньше картинок.

In [8]:
from torchvision import transforms

transform = transforms.Compose(
    [
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        lambda x: 2 * x - 1,
    ]
)
dataset = COCODataset(
    ".",
    subset_name="train2014_5k",
    transform=transform,
    # max_cnt=2500,
)
# assert len(dataset) == 2500  # 2500
assert len(dataset) == 5000

batch_size = 8  # Рекоммендуемы размер батча на Colab

train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True
)
In [9]:
@torch.no_grad()
def prepare_batch(batch, pipe):
    """
    Предобработка батча картинок и текстовых промптов.
    Маппим картинки в латентное пространство VAE.
    Извлекаем эмбеды промптов с помощью текстового энкодера.

    Params:

    Return:
        latents: torch.Tensor([B, 4, 64, 64], dtype=torch.float16)
        prompt_embeds: torch.Tensor([B, 77, D], dtype=torch.float16)
    """

    # Токенизируем промпты
    text_inputs = pipe.tokenizer(
        batch["text"],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    # Извлекаем эмбеды промптов с помощью текстового энкодера
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.cuda())[0]

    # Переводим картинки в латентное пространство VAE
    image = batch["image"].to("cuda", dtype=torch.float16)
    latents = pipe.vae.encode(image).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor
    return latents, prompt_embeds

Подготовка моделей и оптимизатора¶

Для начала создаем обучаемую модель: UNet инициализируемый весами SD1.5. Вам нужно воспользоваться классом UNet2DConditionModel и загрузить отдельно только UNet модель из SD1.5.

Отметим, что эта модель у нас будет храниться в полной точности FP32, потому что обучение параметров в FP16 может приводить к нестабильностям и низкому качеству.

In [10]:
unet = UNet2DConditionModel.from_pretrained(
    "sd-legacy/stable-diffusion-v1-5",
    subfolder="unet",
    device_map="balanced",
)
unet.train()

assert unet.dtype == torch.float32
assert unet.training

Для экономии памяти во время обучения будем учить не параметры самой модели, а добавим в нее обучаемые LoRA адаптеры с малым числом параметров.

LoRA представляет собой маленькую добавку к весам модели, где на одну матрицу весов $W \in \mathbb{R}^{m{\times}n} $ обучаются две низкоранговые матрицы $W_A \in \mathbb{R}^{k{\times}n}$ и $W_B \in \mathbb{R}^{k{\times}m}$, где $k$ - ранг матрицы сильно меньше $m$ и $n$.

Тем самым, новая обученная матрица весов может быть представлена как $\hat{W} = W + \Delta W = W + W^T_B W_A$.
Во время инференса $\Delta W$ можно вмержить в $W$ и получить итоговую модель. Также частая практика оставлять адаптеры как есть, чтобы была возможность для одной базовой модели учить несколько адаптеров под разные задачи и переключаться между ними по необходимости.

Если не мержить адаптеры, то вычисления для линейного слоя происходят как на картинке ниже.

No description has been provided for this image
In [19]:
# Указываем к каким слоям модели мы будет добавлять адаптеры.
lora_modules = [
    "to_q",
    "to_k",
    "to_v",
    "to_out.0",
    "proj_in",
    "proj_out",
    "ff.net.0.proj",
    "ff.net.2",
    "conv1",
    "conv2",
    "conv_shortcut",
    "downsamplers.0.conv",
    "upsamplers.0.conv",
    "time_emb_proj",
]
lora_config = LoraConfig(r=64, target_modules=lora_modules)  # задает ранг у матриц A и B в LoRA.

# Создаем обертку исходной UNet модели с LoRA адаптерами, используя библиотеку PEFT
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")

# Включаем gradient checkpointing - важная техника для экономии памяти во время обучения
cm_unet.enable_gradient_checkpointing()

# Создаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)

# Задаем лосс функцию для CM обжектива. В базовом варианте разумно взять L2
# По умолчанию, она уже выдает усредненное значение по всем размерностям
mse_loss = torch.nn.functional.mse_loss

Задание №2 (0.5 балла, сдается в контесте)¶

Реализация шага DDIM¶

Шаг с помощью DDIM с $\mathbf{x}_t$ на $\mathbf{x}_s$ можно сделать следующим образом:

$ \mathbf{x}_s = DDIM(\epsilon_\theta, \mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $

Вам нужно реализовать эту формулу в уже готовом шаблоне ниже. Чтобы корректно выполнить задание, вам нужно задать $\alpha_t$ и $\sigma_t$ имея DDIMScheduler. **Обратите внимание на аттрибут *scheduler.alphas_cumprod***, который задает $\bar\alpha_{t} = \prod^t_{i=1} (1-\beta_i)$ в классической DDPM формулировке: Denoising Diffusion Probabilistic Models.

In [12]:
def ddim_solver_step(model_output, x_t, t, s, scheduler):  # -> Any:
    """
    Шаг DDIM солвера для VP процесса зашумления и eps-prediction модели
    params:
        model_output: torch.Tensor[B, 4, 64, 64] - предсказание модели - шум eps
        x_t: torch.Tensor[B, 4, 64, 64] - сэмплы на шаге t
        t: torch.Tensor[B] - номер текущего шага
        s: torch.Tensor[B] - номер следующего шага
        scheduler: DDIMScheduler - расписание диффузионного процесса, чтобы получить alpha и sigma
    """
    alphas = torch.sqrt(scheduler.alphas_cumprod)
    sigmas = torch.sqrt(1.0 - scheduler.alphas_cumprod)

    sigmas_s = extract_into_tensor(sigmas, s, x_t.shape)
    alphas_s = extract_into_tensor(alphas, s, x_t.shape)

    sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
    alphas_t = extract_into_tensor(alphas, t, x_t.shape)

    # Выставляем крайние значения alpha и sigma, чтобы выполнялись граничные условия
    alphas_s[s == 0] = 1.0
    sigmas_s[s == 0] = 0.0

    alphas_t[t == 0] = 1.0
    sigmas_t[t == 0] = 0.0

    # Reverse diffusion formula
    x_0 = (x_t - sigmas_t * model_output) / alphas_t

    # DDIM formula
    x_s = alphas_s * x_0 + sigmas_s * model_output

    return x_s

Реализация процесса зашумления (q sample)¶

Аналогично, нам нужен процесс зашумления $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$

$\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$

In [13]:
def q_sample(x, t, scheduler, noise=None):
    alphas = torch.sqrt(scheduler.alphas_cumprod)
    sigmas = torch.sqrt(1.0 - scheduler.alphas_cumprod)

    if noise is None:
        noise = torch.randn_like(x)

    sigmas_t = extract_into_tensor(sigmas, t, x.shape)
    alphas_t = extract_into_tensor(alphas, t, x.shape)

    x_t = x * alphas_t + sigmas_t * noise
    return x_t

Consistency Training¶

Обучение консистенси моделей без учителя называется Consistency Training (CT). В таком случае CM можно рассматривать как отдельный вид генеративных моделей. Давайте начнем именно с этого подхода и обучим нашу первую консистенси модель на базе SD1.5.

Задание №3¶

Задание №3.1 (0.5 балла, сдается в контесте)¶

В консиcтенси дистилляции модель учителя используется для получения второй точки на траектории ODE. Можем ли мы попробовать оценить соседнюю точку аналитически?

Вам предлагается вывести это самим, используя формулу DDIM шага выше и вспомнив, как мы оцениваем скор функции в denoising score matching-e:

$\epsilon_\theta(x_t, t) = - \sigma_t s_\theta(x_t, t)$

$s_\theta(x_t, t) \approx \nabla_{x_t} \log q(x_t) = \mathop{\mathbb{E}}_{\mathbf{x}\sim p_{data}}\left [ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t | \mathbf{x}) \vert \mathbf{x}_t \right ] \approx \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x})$


< YOUR DERIVATION HERE > $$ \mathbf{x}_s = DDIM(\epsilon_\theta, \mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta = \\ \alpha_s \cdot \left(\frac{\mathbf{x}_t + \sigma_t^2 s_\theta(x_t, t)}{\alpha_t} \right) - \sigma_s \sigma_t s_\theta(x_t, t) = \\ \alpha_s \cdot \left(\frac{\mathbf{x}_t + \sigma_t^2 \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x})}{\alpha_t} \right) - \sigma_s \sigma_t \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x}) = \\ \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t^2 \cdot \left(\frac{\mathbf{x}_t - \alpha_t \mathbf{x}_0}{\sigma_t^2}\right)}{\alpha_t} \right) + \sigma_s \sigma_t \cdot \left(\frac{\mathbf{x}_t - \alpha_t \mathbf{x}_0}{\sigma_t^2}\right) = \\ \alpha_s \mathbf{x}_0 + \sigma_s \cdot \left(\frac{\mathbf{x}_t - \alpha_t \mathbf{x}_0}{\sigma_t}\right) $$


Если возникнут трудность, можно обратиться к оригинальной статье.

Теперь реализуем то, что у вас получилось в функции ниже.

In [ ]:
def get_xs_from_xt_naive(x_0, x_t, t, s, scheduler, noise=None, **kwargs):
    """
    Получение точки x_s в CT режиме, т.е., аналитически.
    """
    if x_0 is None:
        x_0 = torch.zeros_like(x_t)

    if x_t is None:
        x_t = q_sample(x_0, t, scheduler, noise=noise)

    if (t == s).all():
        return x_t

    alphas = torch.sqrt(scheduler.alphas_cumprod)
    sigmas = torch.sqrt(1.0 - scheduler.alphas_cumprod)

    alpha_t = extract_into_tensor(alphas, t, x_t.shape)
    sigma_t = extract_into_tensor(sigmas, t, x_t.shape)
    alpha_s = extract_into_tensor(alphas, s, x_t.shape)
    sigma_s = extract_into_tensor(sigmas, s, x_t.shape)

    alpha_t[t == 0] = 1.0
    sigma_t[t == 0] = 0.0

    x_s = x_0.clone().detach()

    cond_1 = (sigma_t != 0) * (sigma_s != 0)
    x_s = torch.where(
        cond_1,
        alpha_s * x_0 + sigma_s * (x_t - alpha_t * x_0) / sigma_t,
        x_s,
    )
    cond_2 = (sigma_t == 0) * (sigma_s != 0)
    x_s = torch.where(
        cond_2,
        q_sample(x_0, s, scheduler, noise),
        x_s,
    )

    return x_s

Задание №3.2¶

Ниже представлен шаблон функции, которая считает лосс для консистенси моделей. Вам нужно правильно заполнить пропуски, чтобы получилась корректная функция.

In [15]:
def cm_loss_template(
    latents,
    prompt_embeds,  # батч латентов и текстовых эмбедов
    unet,
    scheduler,
    # Функции, которые будем постепенно менять из задания к заданию
    loss_fn: callable,
    get_boundary_timesteps: callable,
    get_xs_from_xt: callable,
    num_timesteps=1000,
    step_size=20,  # Указываем с каким интервалом берем шаги s и t.
):
    # Сэмплируем случайные шаги t для каждого элемента батча t ~ U[step_size-1, 999]
    assert num_timesteps == 1000
    num_intervals = num_timesteps // step_size

    index = torch.randint(
        1, num_intervals, (len(latents),), device=latents.device
    ).long()  # [1, num_intervals]
    t = step_size * index - 1
    s = torch.clamp(t - step_size, min=0)
    boundary_timesteps = get_boundary_timesteps(s, num_timesteps=num_timesteps)

    # Сэмплируем x_t
    noise = torch.randn_like(latents)
    x_t = q_sample(latents, t, scheduler, noise=noise)

    # with <YOUR CODE HERE>: # для реализации mixed-precision обучения в задании №4
    with torch.cuda.amp.autocast(dtype=torch.float16):
        noise_pred = unet(
            x_t.float(),
            t,
            encoder_hidden_states=prompt_embeds.float(),
        ).sample

    # Получаем оценку в граничной точке для x_t
    boundary_pred = ddim_solver_step(noise_pred, x_t, t, boundary_timesteps, scheduler)

    x_s = get_xs_from_xt(
        latents,
        x_t,
        t,
        s,
        scheduler,
        prompt_embeds=prompt_embeds,
        noise=noise,
    )

    # Предсказание "таргет моделью"
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        target_noise_pred = unet(x_s, s, encoder_hidden_states=prompt_embeds).sample

    # Получаем оценку в граничной точке для x_s
    boundary_target = ddim_solver_step(target_noise_pred, x_s, s, boundary_timesteps, scheduler)
    loss = loss_fn(boundary_pred, boundary_target)

    return loss
In [16]:
import functools


def get_zero_boundary_timesteps(t, **kwargs):
    """
    Определяем шаги где будут срабатывать граничные условия.
    Для классических СM это t=0.
    """
    return torch.zeros_like(t)


ct_loss = functools.partial(
    cm_loss_template,
    loss_fn=mse_loss,
    get_boundary_timesteps=get_zero_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_naive,
)
assert cm_unet.active_adapter == "ct"

Задание №4¶

Эффективное обучение¶

Данное задание рассчитано на успешное выполнение на colab с бесплатной Tesla T4 c 15GB VRAM. Однако учить даже относительно небольшие T2I модели масштаба SD1.5 уже на коллабе в лоб проблематично.

Для этого нам нужно применить ряд инженерных техник, чтобы уместиться в данный бюджет и учиться за разумное время.

Список техник

  1. Включить gradient checkpointing для обучемой модели
  2. Добавить LoRA (Low Rank Adapters) адаптеры, чтобы учить не все веса, а только 10% добавочных весов
  3. Использовать gradient accumulation, чтобы делать итерацию обучения по бОльшему батчу, чем влезает по памяти
  4. Добавить mixed precision FP16/FP32 обучение модели для скорости. Обычно еще и память экономится, но в случае LoRA обучения + gradient checkpointing на память сильно влиять не должно, но зато станет быстрее.
  5. Мульти-GPU обучение - распределение вычислений по нескольким GPU.

1-2) Мы уже применили за вас выше

3-4) Предстоит реализовать вам самим в соотвествующей секции ниже

5 ) Недоступно, так как работаем на одной карточке

Обучающий цикл¶

Вам дан код обучения модель в полной точности (FP32) c батчом 8. К сожалению, на Tesla T4 мы не влезем по памяти. Поэтому в ячейке ниже вам нужно модифицировать цикл, чтобы он работал в mixed precision FP16 и добавить gradient accumulation.

Про реализацию mixed-precision в pytorch можно перейти по ссылке: Mixed-precision обучение

Обратите внимание: вам еще нужно добавить одну строчку кода в cm_loss_template в соответствующем плейсхолдере.

Замечание: В начале обучения значения лосса должны быть в окрестности 0.0007-0.001. Ничего страшного, что лосс не падает, для CM это нормально. В конце обучения лосс может доходить до 0.005-0.01

In [18]:
def train_loop(model, pipe, train_dataloader, optimizer, loss_fn, num_grad_accum=1):
    torch.cuda.empty_cache()

    scaler = torch.cuda.amp.GradScaler()

    for i, batch in enumerate(tqdm(train_dataloader)):
        with torch.cuda.amp.autocast(dtype=torch.float16):
            latents, prompt_embeds = prepare_batch(batch, pipe)
            loss = loss_fn(latents, prompt_embeds, model, pipe.scheduler) / num_grad_accum

        # Обновляем параметры
        scaler.scale(loss).backward()

        if (i + 1) % num_grad_accum == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        print(f"Loss: {loss.detach().item()}")
In [19]:
num_grad_accum = 2  # обновляем параметры каждые 2 шага

train_loop(cm_unet, pipe, train_dataloader, optimizer, ct_loss, num_grad_accum)
  0%|          | 0/625 [00:00<?, ?it/s]
Loss: 0.00037942116614431143
Loss: 0.0005064992001280189
Loss: 0.00041276204865425825
Loss: 0.0008958675898611546
Loss: 0.00046880001900717616
Loss: 0.0005985702155157924
Loss: 0.0004536816559266299
Loss: 0.00042376824421808124
Loss: 0.0007258296245709062
Loss: 0.00047782735782675445
Loss: 0.0005945760058239102
Loss: 0.0005281904013827443
Loss: 0.001687874086201191
Loss: 0.0008240551687777042
Loss: 0.0008860052330419421
Loss: 0.0008627125062048435
Loss: 0.0004061653162352741
Loss: 0.0004743822501040995
Loss: 0.0008747418178245425
Loss: 0.0005310914712026715
Loss: 0.0005723602953366935
Loss: 0.0008468653541058302
Loss: 0.0004579655360430479
Loss: 0.0005488618044182658
Loss: 0.0005751841817982495
Loss: 0.0004992288304492831
Loss: 0.0007894040900282562
Loss: 0.0007330112857744098
Loss: 0.0008934878278523684
Loss: 0.00043908506631851196
Loss: 0.000619879923760891
Loss: 0.0004753718967549503
Loss: 0.0006294205086305737
Loss: 0.0007459073094651103
Loss: 0.0005154737154953182
Loss: 0.0004189095343463123
Loss: 0.0014230990782380104
Loss: 0.0007570700254291296
Loss: 0.0007128705619834363
Loss: 0.0004862792557105422
Loss: 0.0008252396946772933
Loss: 0.000948379107285291
Loss: 0.0003448168863542378
Loss: 0.0014016738859936595
Loss: 0.00043688761070370674
Loss: 0.000760576338507235
Loss: 0.0004906122339889407
Loss: 0.001050944672897458
Loss: 0.00034305930603295565
Loss: 0.000333520642016083
Loss: 0.001346747623756528
Loss: 0.000518926652148366
Loss: 0.00038751529064029455
Loss: 0.0007347655482590199
Loss: 0.0006552053382620215
Loss: 0.0004693095979746431
Loss: 0.0005048643797636032
Loss: 0.0008245222270488739
Loss: 0.0006888275966048241
Loss: 0.00044082454405725
Loss: 0.0005636227433569729
Loss: 0.0005252123810350895
Loss: 0.0007635715301148593
Loss: 0.0005486704758368433
Loss: 0.0009442295995540917
Loss: 0.0006193040171638131
Loss: 0.0013048275141045451
Loss: 0.00045766972471028566
Loss: 0.0004775456618517637
Loss: 0.0005675374995917082
Loss: 0.004480988718569279
Loss: 0.001354208099655807
Loss: 0.0023091307375580072
Loss: 0.0030480027198791504
Loss: 0.0013057851465418935
Loss: 0.0008849604055285454
Loss: 0.0009611251298338175
Loss: 0.0006008940399624407
Loss: 0.0009718112414702773
Loss: 0.0006195436581037939
Loss: 0.0004468054394237697
Loss: 0.0007318559219129384
Loss: 0.000600254803430289
Loss: 0.0007269569323398173
Loss: 0.0008202603203244507
Loss: 0.0009326831204816699
Loss: 0.0010233625071123242
Loss: 0.0016146933194249868
Loss: 0.0006531896069645882
Loss: 0.0011280549224466085
Loss: 0.0010213699424639344
Loss: 0.0008086289744824171
Loss: 0.0008692542323842645
Loss: 0.0005729414988309145
Loss: 0.0006522737094201148
Loss: 0.0012347043957561255
Loss: 0.0014926702715456486
Loss: 0.0014064067509025335
Loss: 0.0016721924766898155
Loss: 0.0010140687227249146
Loss: 0.0007766926428303123
Loss: 0.0007161159301176667
Loss: 0.0016449993709102273
Loss: 0.0016016623703762889
Loss: 0.0009757575462572277
Loss: 0.0010558852227404714
Loss: 0.002975206822156906
Loss: 0.012377133592963219
Loss: 0.0015951453242450953
Loss: 0.0025169397704303265
Loss: 0.0017778994515538216
Loss: 0.0009804833680391312
Loss: 0.0023517082445323467
Loss: 0.0028763553127646446
Loss: 0.003936590161174536
Loss: 0.001311366562731564
Loss: 0.001428470597602427
Loss: 0.0016911597922444344
Loss: 0.0011702944757416844
Loss: 0.0014147096080705523
Loss: 0.002440669108182192
Loss: 0.0013251928612589836
Loss: 0.004199718590825796
Loss: 0.0011574899544939399
Loss: 0.0014455909840762615
Loss: 0.006965372711420059
Loss: 0.0014533543726429343
Loss: 0.0020802379585802555
Loss: 0.005591132678091526
Loss: 0.0007086016703397036
Loss: 0.0014595562824979424
Loss: 0.003932422958314419
Loss: 0.001268944120965898
Loss: 0.001382922986522317
Loss: 0.003330608131363988
Loss: 0.004483602941036224
Loss: 0.003217090852558613
Loss: 0.0020847718697041273
Loss: 0.0023813711013644934
Loss: 0.0023249583318829536
Loss: 0.0025102400686591864
Loss: 0.0014276099391281605
Loss: 0.0010522021912038326
Loss: 0.0021666488610208035
Loss: 0.0010541814845055342
Loss: 0.0013616065261885524
Loss: 0.00223144399933517
Loss: 0.0017787872347980738
Loss: 0.001991346012800932
Loss: 0.0024860501289367676
Loss: 0.0017141818534582853
Loss: 0.001713483827188611
Loss: 0.002013132907450199
Loss: 0.0016368308570235968
Loss: 0.0020724921487271786
Loss: 0.001502691418863833
Loss: 0.0018795530777424574
Loss: 0.0007520514191128314
Loss: 0.0007461420027539134
Loss: 0.0022452983539551497
Loss: 0.002604313427582383
Loss: 0.0008028405718505383
Loss: 0.0025118214543908834
Loss: 0.0013362554600462317
Loss: 0.0009904057951644063
Loss: 0.0026333308778703213
Loss: 0.0013375040143728256
Loss: 0.0013061071513220668
Loss: 0.0012211732100695372
Loss: 0.0016559758223593235
Loss: 0.0007818201556801796
Loss: 0.0007729750941507518
Loss: 0.0008355433237738907
Loss: 0.0006848564371466637
Loss: 0.0009499641600996256
Loss: 0.0007108630961738527
Loss: 0.0011214565020054579
Loss: 0.00048189060180447996
Loss: 0.0007489272393286228
Loss: 0.000885029265191406
Loss: 0.0008287807577289641
Loss: 0.000563493580557406
Loss: 0.0005383663810789585
Loss: 0.0023214269895106554
Loss: 0.0020749422255903482
Loss: 0.0008023543050512671
Loss: 0.0017955926014110446
Loss: 0.000757447094656527
Loss: 0.0005793230957351625
Loss: 0.0006887734634801745
Loss: 0.0009876531548798084
Loss: 0.0005812561721540987
Loss: 0.00046745891449972987
Loss: 0.0010823803022503853
Loss: 0.0011851361487060785
Loss: 0.0006656574551016092
Loss: 0.0006380220875144005
Loss: 0.0005546664469875395
Loss: 0.0007966226548887789
Loss: 0.0006024678586982191
Loss: 0.0006565562216565013
Loss: 0.0010175753850489855
Loss: 0.001282352488487959
Loss: 0.0007072788430377841
Loss: 0.0015107081271708012
Loss: 0.0012874935055151582
Loss: 0.0008330004056915641
Loss: 0.000535622937604785
Loss: 0.0006450955988839269
Loss: 0.0006074419361539185
Loss: 0.0007074175518937409
Loss: 0.0007562537211924791
Loss: 0.0009953570552170277
Loss: 0.0014050425961613655
Loss: 0.0004630361800082028
Loss: 0.0011367471888661385
Loss: 0.0018684373935684562
Loss: 0.0012699714861810207
Loss: 0.00047859508777037263
Loss: 0.0009107966325245798
Loss: 0.0013107287231832743
Loss: 0.001922330935485661
Loss: 0.0014184056781232357
Loss: 0.0007522629457525909
Loss: 0.000422043027356267
Loss: 0.0010425536893308163
Loss: 0.0011007111752405763
Loss: 0.0011708419770002365
Loss: 0.0011846733978018165
Loss: 0.0006315461359918118
Loss: 0.0012706996640190482
Loss: 0.0014523772988468409
Loss: 0.0006138435564935207
Loss: 0.0017626138869673014
Loss: 0.0005771452561020851
Loss: 0.0010470845736563206
Loss: 0.0020099482499063015
Loss: 0.0007128794677555561
Loss: 0.0008252130355685949
Loss: 0.001020087394863367
Loss: 0.0009030108922161162
Loss: 0.0007460082415491343
Loss: 0.0006069971714168787
Loss: 0.0012493666727095842
Loss: 0.0009998545283451676
Loss: 0.0005000063101761043
Loss: 0.0013536994811147451
Loss: 0.0009585645166225731
Loss: 0.0008933030185289681
Loss: 0.0005902259144932032
Loss: 0.0023559462279081345
Loss: 0.0007550917216576636
Loss: 0.0013092129956930876
Loss: 0.0005594904650934041
Loss: 0.0008394026081077754
Loss: 0.0014222621684893966
Loss: 0.0010701077990233898
Loss: 0.0006709058070555329
Loss: 0.0014043166302144527
Loss: 0.0015168897807598114
Loss: 0.0007551686139777303
Loss: 0.0010281822178512812
Loss: 0.0007850765250623226
Loss: 0.000889840186573565
Loss: 0.0008110835333354771
Loss: 0.0009823492728173733
Loss: 0.0005367578705772758
Loss: 0.0008935442892834544
Loss: 0.0010250592604279518
Loss: 0.0007431853446178138
Loss: 0.0007468818221241236
Loss: 0.0016860202886164188
Loss: 0.0006869430071674287
Loss: 0.0006877119303680956
Loss: 0.0028378237038850784
Loss: 0.0008705396903678775
Loss: 0.0018622581847012043
Loss: 0.0016560859512537718
Loss: 0.0006787670427002013
Loss: 0.001796035561710596
Loss: 0.0012967600487172604
Loss: 0.0009896749397739768
Loss: 0.0019619313534349203
Loss: 0.0014837800990790129
Loss: 0.0011329748667776585
Loss: 0.0012308049481362104
Loss: 0.0011867510620504618
Loss: 0.0009477960411459208
Loss: 0.0011073565110564232
Loss: 0.00047491080476902425
Loss: 0.0004928068956360221
Loss: 0.0007312442758120596
Loss: 0.0009427835466340184
Loss: 0.0009006276377476752
Loss: 0.0012641862267628312
Loss: 0.0019818381406366825
Loss: 0.0012553343549370766
Loss: 0.0012674556346610188
Loss: 0.0006022984161973
Loss: 0.0021270415745675564
Loss: 0.001260775257833302
Loss: 0.0009893679525703192
Loss: 0.0016898381290957332
Loss: 0.0006888847565278411
Loss: 0.0014245009515434504
Loss: 0.0007137987995520234
Loss: 0.0006693258765153587
Loss: 0.0010441727936267853
Loss: 0.0015700546791777015
Loss: 0.0009292969480156898
Loss: 0.0007500495994463563
Loss: 0.0008475390495732427
Loss: 0.001448027789592743
Loss: 0.001073551713488996
Loss: 0.0020768537651747465
Loss: 0.001407198142260313
Loss: 0.0010403033811599016
Loss: 0.0006459858268499374
Loss: 0.0015814948128536344
Loss: 0.0010547223500907421
Loss: 0.0015277417842298746
Loss: 0.0011319030309095979
Loss: 0.0017115375958383083
Loss: 0.0014908439479768276
Loss: 0.001057351822964847
Loss: 0.0009869931964203715
Loss: 0.0018658344633877277
Loss: 0.0008516835514456034
Loss: 0.0009955026907846332
Loss: 0.0018090622033923864
Loss: 0.0011853792238980532
Loss: 0.00108595029450953
Loss: 0.0012533128028735518
Loss: 0.0013152830069884658
Loss: 0.0017013889737427235
Loss: 0.0011044272687286139
Loss: 0.0009303900296799839
Loss: 0.0014018246438354254
Loss: 0.0015869715716689825
Loss: 0.001160049345344305
Loss: 0.0012527592480182648
Loss: 0.00070428685285151
Loss: 0.0013125402620062232
Loss: 0.0021101143211126328
Loss: 0.0010340107837691903
Loss: 0.0011255182325839996
Loss: 0.0011866686400026083
Loss: 0.00203328556381166
Loss: 0.0013247504830360413
Loss: 0.001511218724772334
Loss: 0.00161873793695122
Loss: 0.0007672483334317803
Loss: 0.0011965546291321516
Loss: 0.0010649219620972872
Loss: 0.0021924981847405434
Loss: 0.0018424775917083025
Loss: 0.0015629628906026483
Loss: 0.0005752117140218616
Loss: 0.0017677939031273127
Loss: 0.0010223608696833253
Loss: 0.0015683624660596251
Loss: 0.0013369601219892502
Loss: 0.0012187148677185178
Loss: 0.0013127631973475218
Loss: 0.0021983052138239145
Loss: 0.0007943587261252105
Loss: 0.0018427509348839521
Loss: 0.0031714406795799732
Loss: 0.002107131527736783
Loss: 0.0013186594005674124
Loss: 0.001254415838047862
Loss: 0.0008977922843769193
Loss: 0.0019465356599539518
Loss: 0.0013357085408642888
Loss: 0.0013917243340983987
Loss: 0.0011719500180333853
Loss: 0.002524951007217169
Loss: 0.0005994065431877971
Loss: 0.0009107952937483788
Loss: 0.0011468140874058008
Loss: 0.0010048914700746536
Loss: 0.001208997331559658
Loss: 0.0009585996740497649
Loss: 0.0017437248025089502
Loss: 0.000889518007170409
Loss: 0.0009529950912110507
Loss: 0.0008546350873075426
Loss: 0.001270101871341467
Loss: 0.0016016976442188025
Loss: 0.0014776111347600818
Loss: 0.0013454521540552378
Loss: 0.001419242238625884
Loss: 0.000888891750946641
Loss: 0.0011973816435784101
Loss: 0.0009585467050783336
Loss: 0.001834752387367189
Loss: 0.0011517828097566962
Loss: 0.0010255178203806281
Loss: 0.0009129823301918805
Loss: 0.0009568651439622045
Loss: 0.003211995819583535
Loss: 0.0011029181769117713
Loss: 0.0017177518457174301
Loss: 0.0013423544587567449
Loss: 0.0016658417880535126
Loss: 0.001043348340317607
Loss: 0.0007647225284017622
Loss: 0.0016419119201600552
Loss: 0.001291297608986497
Loss: 0.0007930601132102311
Loss: 0.0010712259681895375
Loss: 0.0009605751256458461
Loss: 0.0010525969555601478
Loss: 0.00116306624840945
Loss: 0.002070025308057666
Loss: 0.0017251630779355764
Loss: 0.0011048103915527463
Loss: 0.0016968096606433392
Loss: 0.002608741167932749
Loss: 0.0006307986914180219
Loss: 0.0007329047657549381
Loss: 0.0010837421286851168
Loss: 0.002312293741852045
Loss: 0.0008948363829404116
Loss: 0.0005451926263049245
Loss: 0.0014810776337981224
Loss: 0.0007154321065172553
Loss: 0.0006251891609281301
Loss: 0.0015294752083718777
Loss: 0.0011080572148784995
Loss: 0.0011535331141203642
Loss: 0.000980229815468192
Loss: 0.001611695159226656
Loss: 0.0011532744392752647
Loss: 0.0026058549992740154
Loss: 0.0013095736503601074
Loss: 0.0005014284979552031
Loss: 0.00201485026627779
Loss: 0.0018339725211262703
Loss: 0.0012314997147768736
Loss: 0.0007580803940072656
Loss: 0.0015576020814478397
Loss: 0.0009253334137611091
Loss: 0.0019092496950179338
Loss: 0.0009791709017008543
Loss: 0.0011775689199566841
Loss: 0.0013838681625202298
Loss: 0.0016961873043328524
Loss: 0.0007651025662198663
Loss: 0.0016154772602021694
Loss: 0.00031895851134322584
Loss: 0.0014226112980395555
Loss: 0.003908277489244938
Loss: 0.001177951693534851
Loss: 0.0010679435217753053
Loss: 0.001458268496207893
Loss: 0.0007180237444117665
Loss: 0.001987336901947856
Loss: 0.000967807718552649
Loss: 0.0025338579434901476
Loss: 0.001234889728948474
Loss: 0.0022540781646966934
Loss: 0.0016532859299331903
Loss: 0.000884941837284714
Loss: 0.0012026582844555378
Loss: 0.002077655866742134
Loss: 0.0008688797242939472
Loss: 0.001374588580802083
Loss: 0.0009795373771339655
Loss: 0.0009474847465753555
Loss: 0.0015431537758558989
Loss: 0.0014992763753980398
Loss: 0.00224619940854609
Loss: 0.0019898926839232445
Loss: 0.0011886644642800093
Loss: 0.001195912016555667
Loss: 0.0009470513323321939
Loss: 0.0021574359852820635
Loss: 0.0017166482284665108
Loss: 0.0023729419335722923
Loss: 0.0015617401804775
Loss: 0.002103858394548297
Loss: 0.0019015774596482515
Loss: 0.0025202487595379353
Loss: 0.001155729521997273
Loss: 0.001037675654515624
Loss: 0.0013360804878175259
Loss: 0.002696119947358966
Loss: 0.0014535182854160666
Loss: 0.004336210899055004
Loss: 0.001846088794991374
Loss: 0.00272024841979146
Loss: 0.0013463783543556929
Loss: 0.0015612333081662655
Loss: 0.0009505340713076293
Loss: 0.0013244056608527899
Loss: 0.0009633367299102247
Loss: 0.0011195067781955004
Loss: 0.0011987247271463275
Loss: 0.002025905065238476
Loss: 0.0010244036093354225
Loss: 0.001045489450916648
Loss: 0.000772233703173697
Loss: 0.0015264117391780019
Loss: 0.0008213156834244728
Loss: 0.0013427824014797807
Loss: 0.0017392404843121767
Loss: 0.0005934350774623454
Loss: 0.00134235096629709
Loss: 0.0005927301826886833
Loss: 0.0013859393075108528
Loss: 0.001019574236124754
Loss: 0.0007304061437025666
Loss: 0.0006748999003320932
Loss: 0.001297136303037405
Loss: 0.0013925565872341394
Loss: 0.0006879764841869473
Loss: 0.0007899506017565727
Loss: 0.0008254270069301128
Loss: 0.0017544485162943602
Loss: 0.0013860096223652363
Loss: 0.000715556787326932
Loss: 0.0013792227255180478
Loss: 0.0010635866783559322
Loss: 0.0009579313336871564
Loss: 0.0008762564975768328
Loss: 0.0012048776261508465
Loss: 0.0010751961963251233
Loss: 0.001382380723953247
Loss: 0.003072496969252825
Loss: 0.0009861232247203588
Loss: 0.0007311536464840174
Loss: 0.0009452502126805484
Loss: 0.0010670819319784641
Loss: 0.0010411642724648118
Loss: 0.0024563304614275694
Loss: 0.0006016616243869066
Loss: 0.0011021350510418415
Loss: 0.0013145486591383815
Loss: 0.0010332380188629031
Loss: 0.0008215782581828535
Loss: 0.0012557602021843195
Loss: 0.001005598227493465
Loss: 0.001854797126725316
Loss: 0.001681183697655797
Loss: 0.0010543595999479294
Loss: 0.0017493953928351402
Loss: 0.0013288147747516632
Loss: 0.0010074828751385212
Loss: 0.0011979619739577174
Loss: 0.0007949782302603126
Loss: 0.0008158296695910394
Loss: 0.001095445710234344
Loss: 0.0015405741287395358
Loss: 0.0006520182359963655
Loss: 0.0015143337659537792
Loss: 0.0009323414415121078
Loss: 0.0014065701980143785
Loss: 0.0026730522513389587
Loss: 0.0009179461630992591
Loss: 0.001791340415365994
Loss: 0.0018697405466809869
Loss: 0.0010157522046938539
Loss: 0.0013694125227630138
Loss: 0.0016324010211974382
Loss: 0.0011204960756003857
Loss: 0.0022372305393218994
Loss: 0.0017677509458735585
Loss: 0.0015592731069773436
Loss: 0.0017771257553249598
Loss: 0.001563400262966752
Loss: 0.0006747982697561383
Loss: 0.0021602315828204155
Loss: 0.003979039844125509
Loss: 0.0011788563570007682
Loss: 0.0013144396943971515
Loss: 0.001798801589757204
Loss: 0.0020310126710683107
Loss: 0.001158342813141644
Loss: 0.0022471256088465452
Loss: 0.002206086879596114
Loss: 0.00145438383333385
Loss: 0.001231637317687273
Loss: 0.000668485532514751
Loss: 0.0017073522321879864
Loss: 0.0017827639821916819
Loss: 0.0007811461109668016
Loss: 0.0011105649173259735
Loss: 0.0018329150043427944
Loss: 0.0009835632517933846
Loss: 0.0018245907267555594
Loss: 0.0013090935535728931
Loss: 0.0019736108370125294
Loss: 0.0009373706416226923
Loss: 0.0013020422775298357
Loss: 0.002027781680226326
Loss: 0.0012231569271534681
Loss: 0.0012198350159451365
Loss: 0.001646938151679933
Loss: 0.001554321963340044
Loss: 0.0009335360373370349
Loss: 0.0011970591731369495
Loss: 0.0011865491978824139
Loss: 0.0014073492493480444
Loss: 0.0008660672465339303
Loss: 0.0011291594710201025
Loss: 0.0009769469033926725
Loss: 0.0016495260642841458
Loss: 0.0011919370153918862
Loss: 0.0014173786621540785
Loss: 0.0006484888726845384
Loss: 0.0008525124285370111
Loss: 0.0009442422888241708
Loss: 0.0006571448757313192
Loss: 0.001279762596823275
Loss: 0.0009764874121174216
Loss: 0.0007751007797196507
Loss: 0.000769660749938339
Loss: 0.0015980221796780825
Loss: 0.0012362840352579951
Loss: 0.0007637885282747447
Loss: 0.0014020904200151563
Loss: 0.0009113442501984537
Loss: 0.0008350464049726725
Loss: 0.0008936412050388753
In [ ]:
# torch.save(cm_unet.state_dict(), "cm_unet.pt")
In [ ]:
# cm_unet.load_state_dict(torch.load("cm_unet.pt"))
Out[ ]:
<All keys matched successfully>

Задание 5¶

Генерация с помощью обученной консистенси модели¶

Настало время погенерировать картинки с помощью нашей модели. Напомним, что мы не можем для консистенси моделей использовать DDIM и другие классические солверы для диффузии. Нам нужен специальный сэмплер для CM, который схематично изображен на картинке ниже:

No description has been provided for this image

Чуть более формально:

$x_{t_n} \sim {N}(0, I)$

$for\ t_i \in [t_n, ..., t_1]:$

  • $\epsilon \leftarrow unet(x_{t_i})$

  • $x_0 \leftarrow DDIM(\epsilon, x_{t_i}, t_i, 0)$

  • $x_{t_{i-1}} \leftarrow q(x_{t_{i-1}} | x_0)$

Classifier-free guidance (CFG)

Также вам надо реализовать поддержку CFG в CM сэмплирование. Вспомним формулу:

$\epsilon_w = {\color{blue}{\epsilon_{uncond}}} + w \cdot (\epsilon_{cond} - \epsilon_{uncond})$, где $w \geq 1$

Обратим внимание, что режим "без гайденса" соотвествует $w = 1$, что немного контринтуитивно, но в большинстве реализаций будет встречаться именно такой вид этой формулы.

In [17]:
@torch.no_grad()
def consistency_sampling(
    pipe, prompt, num_inference_steps=4, generator=None, num_images_per_prompt=4, guidance_scale=1
):
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Извлекаем эмбеды из текстовых промптов. Реализуйте вызов pipe.encode_prompt
    do_classifier_free_guidance = guidance_scale > 1
    prompt_embeds, null_prompt_embeds = pipe.encode_prompt(
        prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )
    assert prompt_embeds.dtype == torch.float16

    # Настраиваем параметры scheduler-a
    assert pipe.scheduler.config["timestep_spacing"] == "trailing"
    pipe.scheduler.set_timesteps(num_inference_steps)

    # Создаем батч латентов из N(0,I)
    latents = torch.randn(
        (
            batch_size * num_images_per_prompt,
            pipe.unet.in_channels,
            pipe.unet.sample_size,
            pipe.unet.sample_size,
        ),
        device=device,
        generator=generator,
        dtype=torch.float16,
    )

    for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
        t = torch.tensor([t] * len(latents)).to(device)
        zero_t = torch.tensor([0] * len(latents)).to(device)

        cond_noise_pred = pipe.unet(
            latents,
            t,
            encoder_hidden_states=prompt_embeds,
        ).sample

        if do_classifier_free_guidance:
            uncond_noise_pred = pipe.unet(
                latents, t, encoder_hidden_states=null_prompt_embeds
            ).sample
            noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
        else:
            noise_pred = cond_noise_pred

        # Получаем x_0 оценку из x_t
        x_0 = ddim_solver_step(noise_pred, latents, t, zero_t, scheduler=pipe.scheduler)

        if i + 1 < num_inference_steps:
            # Переход на следующий шаг
            s = pipe.scheduler.timesteps[i + 1]
            s = torch.tensor([s] * len(latents)).to(device)

            noise = torch.normal(mean=torch.zeros_like(latents), generator=generator)
            latents = q_sample(x_0, s, pipe.scheduler, noise=noise)
        else:
            # Последний шаг
            latents = x_0

        latents = latents.half()

    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    do_denormalize = [True] * image.shape[0]
    image = pipe.image_processor.postprocess(
        image, output_type="pil", do_denormalize=do_denormalize
    )
    return image

Попробуем сгененировать что-то нашей моделью. Можно поиграться с разными сидами и гайденс скейлами.

Референс, что примерно должно получиться на этом этапе для guidance_scale=2. Как видите, картинки стали почетче, но пока все еще так себе.

img

In [21]:
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == "ct"

generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 2

# Заменяем генерацию пайплайном на наше сэмплирование.
images = consistency_sampling(
    pipe=pipe,
    prompt="A sad puppy with large eyes",
    generator=generator,
    num_images_per_prompt=4,
    guidance_scale=guidance_scale,
)

visualize_images(images)
/root/miniconda3/envs/pytorch-env/lib/python3.10/site-packages/peft/tuners/lora/model.py:375: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.
  return getattr(self.model, name)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
In [ ]:
# !fuser -v /dev/nvidia* -k

Consistency Distillation¶

Задание №6¶

Теперь давайте попробуем перейти к постановке дистилляции, где шаг из $x_t$ в $x_s$ будет делаться не аналитически, а c помощью модели учителя.

$\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$

$\mathbf{x}_s = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, s)$

Замечание: В text-to-image генерации classifier-free guidance (CFG) играет очень важную роль для получения хорошего качества с помощью диффузии. CFG меняет траектории ODE и раз нам он важен, то давайте и дистиллировать траектории с CFG.

Поэтому для получения точки $\mathbf{x}_{s}$ мы будем использовать шаг учителя с CFG. Это важное отличие от CT сеттинга - там мы не можем моделировать гайденс.

In [20]:
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32

# Добавляем новые LoRA адаптеры для CD модели
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")

# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
In [21]:
@torch.no_grad()
def get_xs_from_xt_with_teacher(
    x_0,
    x_t,
    t,
    s,  # Not all arguments may be needed
    scheduler,
    prompt_embeds,
    teacher_unet,
    guidance_scale,
    **kwargs
):
    # Делаем предсказание учителем в кондишион случае: подаем эмбеды текста
    cond_noise_pred = teacher_unet(
        sample=x_t, timestep=t, encoder_hidden_states=prompt_embeds
    ).sample

    # Для CFG нам нужно делать предсказания в unconditional случае.
    # Для T2I моделей, мы будем это моделировать предсказаниями для пустого промпта ""
    # Извлечем эмбеды из пустого промпта и размножить их до размера батча
    uncond_input_ids = pipe.tokenizer(
        [""], return_tensors="pt", padding="max_length", max_length=77
    ).input_ids.to("cuda")

    uncond_prompt_embeds = pipe.text_encoder(uncond_input_ids)[0].expand(*prompt_embeds.shape)
    # Затем прогоняем модель для пустых промптов
    uncond_noise_pred = teacher_unet(
        sample=x_t,
        timestep=t,
        encoder_hidden_states=uncond_prompt_embeds,
    ).sample

    # Применяем CFG формулу и получаем итоговый предикт учителя
    noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
    # noise_pred = (1 + guidance_scale) * cond_noise_pred - guidance_scale * uncond_noise_pred

    # Получаем x_s из x_t
    x_s = ddim_solver_step(noise_pred, x_t, t, s, scheduler)

    return x_s


# Сразу зададим внутрь модель учителя и guidance_scale
get_xs_from_xt_with_teacher = functools.partial(
    get_xs_from_xt_with_teacher,
    teacher_unet=teacher_unet,
    guidance_scale=7.5,
)

Еще, как показано в работе Improved Techniques for Training Consistency Models. L2 лосс не самый оптимальный выбор для консистенси моделей. Давайте в CD обучении также заменим MSE лосс на pseudo-huber лосс из статьи.

In [22]:
def pseudo_huber_loss(x: torch.Tensor, y: torch.Tensor, c=0.001):
    loss = torch.mean(torch.sqrt(torch.square(x - y) + c**2) - c)
    return loss
In [23]:
cd_loss = functools.partial(
    cm_loss_template,
    loss_fn=pseudo_huber_loss,
    get_boundary_timesteps=get_zero_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_with_teacher,
)

assert cm_unet.active_adapter == "cd"

Теперь обучим модель в CD режиме

In [24]:
num_grad_accum = 2  # обновляем параметры каждые 2 шага

train_loop(cm_unet, pipe, train_dataloader, optimizer, cd_loss, num_grad_accum)
  0%|          | 0/625 [00:00<?, ?it/s]
Loss: 0.009496292099356651
Loss: 0.011937526986002922
Loss: 0.01138567179441452
Loss: 0.015678219497203827
Loss: 0.009431742131710052
Loss: 0.010772095993161201
Loss: 0.016768421977758408
Loss: 0.012787067331373692
Loss: 0.012087555602192879
Loss: 0.017596885561943054
Loss: 0.009169336408376694
Loss: 0.012779573909938335
Loss: 0.011597014963626862
Loss: 0.011395310051739216
Loss: 0.008672960102558136
Loss: 0.009967935271561146
Loss: 0.01345849595963955
Loss: 0.010138191282749176
Loss: 0.011719721369445324
Loss: 0.011032583191990852
Loss: 0.013462868519127369
Loss: 0.007578674703836441
Loss: 0.013694102875888348
Loss: 0.016981203109025955
Loss: 0.015872275456786156
Loss: 0.012384796515107155
Loss: 0.015420867130160332
Loss: 0.018556609749794006
Loss: 0.008531913161277771
Loss: 0.0081110168248415
Loss: 0.012532942928373814
Loss: 0.008422580547630787
Loss: 0.01047995314002037
Loss: 0.012827962636947632
Loss: 0.011391503736376762
Loss: 0.01443835161626339
Loss: 0.012678487226366997
Loss: 0.012677878141403198
Loss: 0.013434009626507759
Loss: 0.01593075878918171
Loss: 0.015876853838562965
Loss: 0.01336511317640543
Loss: 0.0109160291031003
Loss: 0.016093868762254715
Loss: 0.018056068569421768
Loss: 0.014975743368268013
Loss: 0.013509530574083328
Loss: 0.013921107165515423
Loss: 0.01753494143486023
Loss: 0.011128033511340618
Loss: 0.01605379767715931
Loss: 0.013949436135590076
Loss: 0.019396783784031868
Loss: 0.01391231082379818
Loss: 0.011289631016552448
Loss: 0.014366846531629562
Loss: 0.017452500760555267
Loss: 0.01670077256858349
Loss: 0.014370341785252094
Loss: 0.009669305756688118
Loss: 0.018039245158433914
Loss: 0.015545105561614037
Loss: 0.01798805594444275
Loss: 0.020122313871979713
Loss: 0.022047407925128937
Loss: 0.014256610535085201
Loss: 0.022882739081978798
Loss: 0.010917379520833492
Loss: 0.020264677703380585
Loss: 0.021845266222953796
Loss: 0.011254949495196342
Loss: 0.019297216087579727
Loss: 0.016736852005124092
Loss: 0.012872754596173763
Loss: 0.012026213109493256
Loss: 0.02280844748020172
Loss: 0.022811302915215492
Loss: 0.027594469487667084
Loss: 0.0183707345277071
Loss: 0.03144780918955803
Loss: 0.018408607691526413
Loss: 0.025689981877803802
Loss: 0.03236237168312073
Loss: 0.025275297462940216
Loss: 0.025577928870916367
Loss: 0.01924975961446762
Loss: 0.01932159811258316
Loss: 0.02623400092124939
Loss: 0.03247181326150894
Loss: 0.02669796720147133
Loss: 0.01757189631462097
Loss: 0.030440235510468483
Loss: 0.027177581563591957
Loss: 0.026035990566015244
Loss: 0.0243036188185215
Loss: 0.022974105551838875
Loss: 0.024366341531276703
Loss: 0.01629873737692833
Loss: 0.013320215046405792
Loss: 0.02426120638847351
Loss: 0.0341351293027401
Loss: 0.026354236528277397
Loss: 0.03175276145339012
Loss: 0.03364691883325577
Loss: 0.01664678007364273
Loss: 0.03079746663570404
Loss: 0.031287822872400284
Loss: 0.03689461573958397
Loss: 0.036125171929597855
Loss: 0.02518559619784355
Loss: 0.03783692419528961
Loss: 0.03708384931087494
Loss: 0.017160164192318916
Loss: 0.02529047429561615
Loss: 0.019349735230207443
Loss: 0.020972581580281258
Loss: 0.0210052952170372
Loss: 0.026679249480366707
Loss: 0.02700812742114067
Loss: 0.028842540457844734
Loss: 0.03562235087156296
Loss: 0.04374111443758011
Loss: 0.03097364492714405
Loss: 0.012074470520019531
Loss: 0.02567477338016033
Loss: 0.018477700650691986
Loss: 0.01651831530034542
Loss: 0.016548018902540207
Loss: 0.022041406482458115
Loss: 0.04033474624156952
Loss: 0.02967374213039875
Loss: 0.024199943989515305
Loss: 0.019739586859941483
Loss: 0.0354057252407074
Loss: 0.013216789811849594
Loss: 0.03076568804681301
Loss: 0.025727862492203712
Loss: 0.022994527593255043
Loss: 0.026900198310613632
Loss: 0.02344381809234619
Loss: 0.01690497249364853
Loss: 0.021254222840070724
Loss: 0.02345280349254608
Loss: 0.02174125239253044
Loss: 0.009578527882695198
Loss: 0.01897246576845646
Loss: 0.03501308709383011
Loss: 0.03864634037017822
Loss: 0.021854937076568604
Loss: 0.02508222497999668
Loss: 0.011814702302217484
Loss: 0.021309832111001015
Loss: 0.01157014723867178
Loss: 0.04338601976633072
Loss: 0.013866527006030083
Loss: 0.04070926457643509
Loss: 0.01488990243524313
Loss: 0.015293469652533531
Loss: 0.013237053528428078
Loss: 0.027621319517493248
Loss: 0.016671590507030487
Loss: 0.028909889981150627
Loss: 0.01657666452229023
Loss: 0.02959974855184555
Loss: 0.012124164961278439
Loss: 0.029925191774964333
Loss: 0.03146837651729584
Loss: 0.023782871663570404
Loss: 0.040163569152355194
Loss: 0.036783743649721146
Loss: 0.01889895647764206
Loss: 0.013891350477933884
Loss: 0.02183867245912552
Loss: 0.027528773993253708
Loss: 0.02591150812804699
Loss: 0.03047752007842064
Loss: 0.021378343924880028
Loss: 0.019397836178541183
Loss: 0.0239555723965168
Loss: 0.03431772440671921
Loss: 0.017908576875925064
Loss: 0.02693500742316246
Loss: 0.02115318737924099
Loss: 0.03414390608668327
Loss: 0.019907239824533463
Loss: 0.021337632089853287
Loss: 0.03761674836277962
Loss: 0.020723192021250725
Loss: 0.016744276508688927
Loss: 0.02132866531610489
Loss: 0.024194566532969475
Loss: 0.011753836646676064
Loss: 0.017650388181209564
Loss: 0.039314355701208115
Loss: 0.0260311271995306
Loss: 0.019218210130929947
Loss: 0.016347380355000496
Loss: 0.019863583147525787
Loss: 0.029556449502706528
Loss: 0.023201841861009598
Loss: 0.028976373374462128
Loss: 0.03007388487458229
Loss: 0.017817983403801918
Loss: 0.02871834486722946
Loss: 0.03152109682559967
Loss: 0.025081973522901535
Loss: 0.02826644666492939
Loss: 0.022341718897223473
Loss: 0.021324416622519493
Loss: 0.019095992669463158
Loss: 0.0325055755674839
Loss: 0.04651474207639694
Loss: 0.020459052175283432
Loss: 0.045019619166851044
Loss: 0.010087361559271812
Loss: 0.0239616297185421
Loss: 0.03117445483803749
Loss: 0.027433453127741814
Loss: 0.019093353301286697
Loss: 0.013201046735048294
Loss: 0.024706676602363586
Loss: 0.022092612460255623
Loss: 0.02983454056084156
Loss: 0.015126354061067104
Loss: 0.018014488741755486
Loss: 0.028299586847424507
Loss: 0.011203239671885967
Loss: 0.024925164878368378
Loss: 0.029286377131938934
Loss: 0.015891991555690765
Loss: 0.021685559302568436
Loss: 0.0263700932264328
Loss: 0.01141943410038948
Loss: 0.014658866450190544
Loss: 0.018147766590118408
Loss: 0.012419617734849453
Loss: 0.01503115613013506
Loss: 0.032400257885456085
Loss: 0.014440803788602352
Loss: 0.020771950483322144
Loss: 0.011454792693257332
Loss: 0.01948883943259716
Loss: 0.027998320758342743
Loss: 0.012758666649460793
Loss: 0.028439035639166832
Loss: 0.009769264608621597
Loss: 0.01960146240890026
Loss: 0.02903605066239834
Loss: 0.030036015436053276
Loss: 0.0213596411049366
Loss: 0.02497616782784462
Loss: 0.01905214600265026
Loss: 0.0327293835580349
Loss: 0.03161487728357315
Loss: 0.01919790357351303
Loss: 0.03405335545539856
Loss: 0.01812189444899559
Loss: 0.03108360804617405
Loss: 0.021709628403186798
Loss: 0.01338990405201912
Loss: 0.03325280547142029
Loss: 0.03440108522772789
Loss: 0.02291320264339447
Loss: 0.03339104354381561
Loss: 0.017286982387304306
Loss: 0.026906754821538925
Loss: 0.020866746082901955
Loss: 0.02603893354535103
Loss: 0.017456278204917908
Loss: 0.009823089465498924
Loss: 0.027211233973503113
Loss: 0.014216199517250061
Loss: 0.03865071386098862
Loss: 0.033714476972818375
Loss: 0.012833082117140293
Loss: 0.01713794656097889
Loss: 0.02267385646700859
Loss: 0.02660496160387993
Loss: 0.015578197315335274
Loss: 0.02379683591425419
Loss: 0.024973252788186073
Loss: 0.02699943073093891
Loss: 0.023325879126787186
Loss: 0.021142350509762764
Loss: 0.018284672871232033
Loss: 0.031939249485731125
Loss: 0.019638104364275932
Loss: 0.023888016119599342
Loss: 0.0175599567592144
Loss: 0.020038485527038574
Loss: 0.03570227324962616
Loss: 0.021797675639390945
Loss: 0.019452963024377823
Loss: 0.023516086861491203
Loss: 0.024924416095018387
Loss: 0.03526606783270836
Loss: 0.030382489785552025
Loss: 0.02078934758901596
Loss: 0.016493599861860275
Loss: 0.022489190101623535
Loss: 0.020600879564881325
Loss: 0.022667860612273216
Loss: 0.025904497131705284
Loss: 0.018586423248052597
Loss: 0.02097979746758938
Loss: 0.026850156486034393
Loss: 0.029706312343478203
Loss: 0.028161903843283653
Loss: 0.023522838950157166
Loss: 0.02759244665503502
Loss: 0.013026240281760693
Loss: 0.007144790608435869
Loss: 0.04768797382712364
Loss: 0.015131542459130287
Loss: 0.02647995576262474
Loss: 0.020290199667215347
Loss: 0.021466249600052834
Loss: 0.019636791199445724
Loss: 0.02598472312092781
Loss: 0.017441246658563614
Loss: 0.020738672465085983
Loss: 0.019031837582588196
Loss: 0.01593819446861744
Loss: 0.03732496127486229
Loss: 0.031962551176548004
Loss: 0.02812816947698593
Loss: 0.019592955708503723
Loss: 0.03974433243274689
Loss: 0.006029962562024593
Loss: 0.007793230935931206
Loss: 0.028224937617778778
Loss: 0.019553285092115402
Loss: 0.008423061110079288
Loss: 0.02248038537800312
Loss: 0.023505384102463722
Loss: 0.02730429358780384
Loss: 0.030865369364619255
Loss: 0.015492947772145271
Loss: 0.019171684980392456
Loss: 0.022700998932123184
Loss: 0.030046263709664345
Loss: 0.03841554373502731
Loss: 0.019631966948509216
Loss: 0.01679622009396553
Loss: 0.023311205208301544
Loss: 0.03165679797530174
Loss: 0.02817779779434204
Loss: 0.01498435065150261
Loss: 0.016916709020733833
Loss: 0.009516908787190914
Loss: 0.013914771378040314
Loss: 0.03198603168129921
Loss: 0.012353334575891495
Loss: 0.015339156612753868
Loss: 0.016120197251439095
Loss: 0.006420500576496124
Loss: 0.019626885652542114
Loss: 0.024988528341054916
Loss: 0.028647251427173615
Loss: 0.010206865146756172
Loss: 0.020918216556310654
Loss: 0.025295697152614594
Loss: 0.020878784358501434
Loss: 0.01758619397878647
Loss: 0.023583296686410904
Loss: 0.027050381526350975
Loss: 0.02011142671108246
Loss: 0.01409129612147808
Loss: 0.015736214816570282
Loss: 0.02060209959745407
Loss: 0.027128949761390686
Loss: 0.023446915671229362
Loss: 0.036001600325107574
Loss: 0.018511656671762466
Loss: 0.01920720376074314
Loss: 0.029864810407161713
Loss: 0.027200475335121155
Loss: 0.016171883791685104
Loss: 0.020199786871671677
Loss: 0.025286247953772545
Loss: 0.02033567801117897
Loss: 0.04276342689990997
Loss: 0.021857809275388718
Loss: 0.017168421298265457
Loss: 0.023361019790172577
Loss: 0.03044249303638935
Loss: 0.02784004807472229
Loss: 0.03880874812602997
Loss: 0.02639441192150116
Loss: 0.029883740469813347
Loss: 0.022406859323382378
Loss: 0.023495040833950043
Loss: 0.01571938954293728
Loss: 0.021098390221595764
Loss: 0.01676984690129757
Loss: 0.009640535339713097
Loss: 0.013287393376231194
Loss: 0.01931208185851574
Loss: 0.022366559132933617
Loss: 0.018939098343253136
Loss: 0.02624857984483242
Loss: 0.018784690648317337
Loss: 0.031175360083580017
Loss: 0.026192443445324898
Loss: 0.02186425030231476
Loss: 0.02652943879365921
Loss: 0.024367431178689003
Loss: 0.016740046441555023
Loss: 0.024467386305332184
Loss: 0.02558651939034462
Loss: 0.01736772432923317
Loss: 0.03328068554401398
Loss: 0.023520514369010925
Loss: 0.028924889862537384
Loss: 0.014891618862748146
Loss: 0.017437539994716644
Loss: 0.028767094016075134
Loss: 0.03257367014884949
Loss: 0.02516304701566696
Loss: 0.020238468423485756
Loss: 0.022964395582675934
Loss: 0.024343490600585938
Loss: 0.03130774572491646
Loss: 0.024128004908561707
Loss: 0.015969816595315933
Loss: 0.0356704480946064
Loss: 0.023618213832378387
Loss: 0.011910987086594105
Loss: 0.02276741713285446
Loss: 0.01601453870534897
Loss: 0.023953963071107864
Loss: 0.02076077088713646
Loss: 0.023621631786227226
Loss: 0.008149929344654083
Loss: 0.011193893849849701
Loss: 0.013779919594526291
Loss: 0.019075622782111168
Loss: 0.011332545429468155
Loss: 0.018374189734458923
Loss: 0.00786417443305254
Loss: 0.028014056384563446
Loss: 0.02040540799498558
Loss: 0.02935778722167015
Loss: 0.038291558623313904
Loss: 0.03702105954289436
Loss: 0.035803474485874176
Loss: 0.017483744770288467
Loss: 0.021001631394028664
Loss: 0.03384053707122803
Loss: 0.034847937524318695
Loss: 0.025064866989850998
Loss: 0.01403476856648922
Loss: 0.014985454268753529
Loss: 0.01871734857559204
Loss: 0.027287650853395462
Loss: 0.026096075773239136
Loss: 0.01895304024219513
Loss: 0.017183424904942513
Loss: 0.026206085458397865
Loss: 0.026633020490407944
Loss: 0.02216288447380066
Loss: 0.0564495213329792
Loss: 0.026784945279359818
Loss: 0.025381412357091904
Loss: 0.015770187601447105
Loss: 0.03381894528865814
Loss: 0.026263797655701637
Loss: 0.03165022283792496
Loss: 0.019144399091601372
Loss: 0.017231730744242668
Loss: 0.024026455357670784
Loss: 0.013367719948291779
Loss: 0.017525220289826393
Loss: 0.0162955354899168
Loss: 0.018693160265684128
Loss: 0.023483015596866608
Loss: 0.01597534865140915
Loss: 0.019978616386651993
Loss: 0.022129325196146965
Loss: 0.03937963768839836
Loss: 0.030721209943294525
Loss: 0.024508433416485786
Loss: 0.019966108724474907
Loss: 0.027386073023080826
Loss: 0.02077588625252247
Loss: 0.017833830788731575
Loss: 0.01819556951522827
Loss: 0.015298066660761833
Loss: 0.01772412098944187
Loss: 0.00913072470575571
Loss: 0.017517555505037308
Loss: 0.02916971780359745
Loss: 0.029484529048204422
Loss: 0.0165090374648571
Loss: 0.028805581852793694
Loss: 0.018195562064647675
Loss: 0.01519365981221199
Loss: 0.018158389255404472
Loss: 0.019854076206684113
Loss: 0.031852155923843384
Loss: 0.01860187202692032
Loss: 0.04604485630989075
Loss: 0.02576640620827675
Loss: 0.028568346053361893
Loss: 0.027869362384080887
Loss: 0.023324253037571907
Loss: 0.014252375811338425
Loss: 0.014558786526322365
Loss: 0.017063356935977936
Loss: 0.02867523767054081
Loss: 0.01717209443449974
Loss: 0.0275314599275589
Loss: 0.022404879331588745
Loss: 0.03226952999830246
Loss: 0.011252181604504585
Loss: 0.02064398303627968
Loss: 0.023048900067806244
Loss: 0.023910420015454292
Loss: 0.015921270474791527
Loss: 0.02091893181204796
Loss: 0.023140713572502136
Loss: 0.03254833072423935
Loss: 0.009130861610174179
Loss: 0.02315135858952999
Loss: 0.008089021779596806
Loss: 0.019211553037166595
Loss: 0.029322996735572815
Loss: 0.018330730497837067
Loss: 0.026580996811389923
Loss: 0.02034873142838478
Loss: 0.027433721348643303
Loss: 0.0277324877679348
Loss: 0.013611623086035252
Loss: 0.021129827946424484
Loss: 0.034579452127218246
Loss: 0.03219705820083618
Loss: 0.03291945159435272
Loss: 0.014857925474643707
Loss: 0.01701737567782402
Loss: 0.01582316681742668
Loss: 0.023910846561193466
Loss: 0.028317280113697052
Loss: 0.02134905755519867
Loss: 0.01620522327721119
Loss: 0.026204746216535568
Loss: 0.02195369079709053
Loss: 0.036061711609363556
Loss: 0.02561189793050289
Loss: 0.027346983551979065
Loss: 0.02108931541442871
Loss: 0.025453072041273117
Loss: 0.014583488926291466
Loss: 0.010639210231602192
Loss: 0.008199464529752731
Loss: 0.026678375899791718
Loss: 0.028658444061875343
Loss: 0.028008539229631424
Loss: 0.022333500906825066
Loss: 0.012294890359044075
Loss: 0.02797851897776127
Loss: 0.02465151622891426
Loss: 0.045541852712631226
Loss: 0.03206247463822365
Loss: 0.021125372499227524
Loss: 0.01975339464843273
Loss: 0.022532137110829353
Loss: 0.03348763287067413
Loss: 0.040923312306404114
Loss: 0.013663570396602154
Loss: 0.028191063553094864
Loss: 0.01757141947746277
Loss: 0.02143767476081848
Loss: 0.026715552434325218
Loss: 0.026797782629728317
Loss: 0.020081181079149246
Loss: 0.015572980977594852
Loss: 0.005896571557968855
Loss: 0.026485303416848183
Loss: 0.014912683516740799
Loss: 0.01448056660592556
Loss: 0.013832712545990944
Loss: 0.029979638755321503
Loss: 0.02159600891172886
Loss: 0.015216835774481297
Loss: 0.02701541781425476
Loss: 0.02643381804227829
Loss: 0.018457869067788124
Loss: 0.021007366478443146
Loss: 0.023309975862503052
Loss: 0.01844821311533451
Loss: 0.023704426363110542
Loss: 0.016951581463217735
Loss: 0.015772562474012375
Loss: 0.036733463406562805
Loss: 0.029061028733849525
Loss: 0.02313855290412903
Loss: 0.02438279613852501
Loss: 0.02331392839550972
Loss: 0.025476699694991112
Loss: 0.01946019008755684
Loss: 0.02864803746342659
Loss: 0.01749674789607525
Loss: 0.014430246315896511
Loss: 0.017035778611898422
Loss: 0.01911088451743126
Loss: 0.013544639572501183
Loss: 0.025872185826301575
Loss: 0.022109074518084526
Loss: 0.027618028223514557
Loss: 0.013509759679436684
Loss: 0.01674867607653141
Loss: 0.04614322632551193
Loss: 0.023746397346258163
Loss: 0.0319010466337204
Loss: 0.03761402145028114
Loss: 0.02665504440665245
Loss: 0.016361359506845474
Loss: 0.01571817882359028
Loss: 0.025462744757533073
Loss: 0.023143082857131958
Loss: 0.04236052185297012
Loss: 0.022046977654099464
Loss: 0.0151630574837327
Loss: 0.039357710629701614
Loss: 0.02524743601679802
In [ ]:
# torch.save(cm_unet.state_dict(), "cd_unet.pt")
In [ ]:
# cm_unet.load_state_dict(torch.load("cd_unet.pt"))
Out[ ]:
<All keys matched successfully>

Снова сэмплируем¶

Обратим внимание, что тут мы сэмпилруем без гайденса, потому что мы его уже частично прокинули в модель, когда делали шаг учителя с CFG.

Снова для референса приводим картинки на этом этапе:

img

Ваши картинки не обязаны совпадать: у вас могут быть немного менее/более качественные. Небольшая разница по качеству на оценку не влиет.

In [25]:
# Подставляем нашу новую обученную модель в пайплайн
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == "cd"

generator = torch.Generator(device="cuda").manual_seed(0)
guidance_scale = 1

images = consistency_sampling(
    pipe=pipe,
    prompt="A sad puppy with large eyes",
    generator=generator,
    num_images_per_prompt=4,
    guidance_scale=guidance_scale,
)


visualize_images(images)
/root/miniconda3/envs/pytorch-env/lib/python3.10/site-packages/peft/tuners/lora/model.py:375: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.
  return getattr(self.model, name)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

Давайте посмотрим на картинки для других промптов¶

In [ ]:
validation_prompts = [
    "A sad puppy with large eyes",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    "A girl with pale blue hair and a cami tank top",
    "A lighthouse in a giant wave, origami style",
    "belle epoque, christmas, red house in the forest, photo realistic, 8k",
    "A small cactus with a happy face in the Sahara desert",
    "Green commercial building with refrigerator and refrigeration units outside",
]
In [ ]:
for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(0)

    images = consistency_sampling(
        pipe=pipe,
        prompt=prompt,
        generator=generator,
        num_images_per_prompt=4,
        guidance_scale=guidance_scale,
    )

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Multi-boundary Сonsistency Distillation¶

В конце мы рассмотрим недавнюю модификацию CD, Multi-boundary CD, где интегрируем не всю траекторию сразу и потом сэмплируем с возвращением назад, а разбиваем траектории на $K$ отрезков и применяет CD внутри каждого отрезка независимо. Например, на картинке выше у нас два отрезка: зеленым и красным выделены две граничные точки. Для классического CD, рассмотренного ранее, у нас только одна граничная точка в $t = 0$

Обратим внимание, что сэмплирование становится детерминистичным и можно снова использовать DDIM солвер, где число шагов равно числу интервалов $K$, на которые мы разбили траектории во время обучения.

Этот метод гораздо лучше работает чем обычный CD, потому что решать задачу CD на отрезках, а не на всей траектории, гораздо проще. В текущем задании мы разобьем траекторию на $K=4$ отрезка.

Подробнее почитать можно в этой статье.

Задание №7 (0.25 балла, сдается в контесте)¶

Ниже реализуйте функцию, которая для $K=4$ отрезков будет сопоставлять таймстепам соответствующие граничные точки.

Например, для $K=2$ отрезков граничные точки будут: [0, 499]

$0 \leq t < 499$ -> граничная точка - $0$

$499 \leq t < 999$ -> граничная точка - $499$

Замечание: помним, что интервал между $t$ и $s$ - 20 шагов.

In [ ]:
import torch


def get_multi_boundary_timesteps(
    timesteps,
    num_boundaries=4,
    num_timesteps=1000,
):
    """
    Для батча таймстепов определяем соответствующие граничные точки.
    params:
        timesteps: torch.Tensor(batch_size, device='cuda')
    returns:
        boundary_timesteps: torch.Tensor(batch_size, device='cuda')
    """
    # Здесь важно повыводить timesteps и boundary_timesteps перед обучением,
    # чтобы не перелетать граничные точки и при этом иногда попадать в них.
    step_size = 20

    boundary_points = torch.zeros_like(timesteps)

    step = num_timesteps // num_boundaries

    boundaries = torch.arange(0, num_timesteps - 1, step, device=timesteps.device).long()

    boundaries = boundaries - (boundaries > 0).long()

    for i, t in enumerate(timesteps):
        if t < 0:
            boundary_points[i] = 0
        else:
            boundary_points[i] = boundaries[boundaries <= t][-1]

    return boundary_points


timesteps = torch.tensor([-1, 0, 1, 498, 499, 500, 501, 998, 999, 1000])
num_boundaries = 4  # Implied by the step size and total timesteps
num_timesteps = 1000
step_size = 20

boundary_points = get_multi_boundary_timesteps(
    timesteps,
    num_boundaries,
    num_timesteps,
)
print(boundary_points)  # Outputs the boundary points for each timestep
tensor([  0,   0,   0, 249, 499, 499, 499, 749, 749, 749])
In [25]:
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
In [26]:
cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")

optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)

multi_cd_loss = functools.partial(
    cm_loss_template,
    loss_fn=pseudo_huber_loss,
    get_boundary_timesteps=get_multi_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_with_teacher,
)
assert cm_unet.active_adapter == "multi-cd"

Теперь обучим Multi-boundary CD модель

In [27]:
num_grad_accum = 2  # обновляем параметры каждые 2 шага

train_loop(cm_unet, pipe, train_dataloader, optimizer, multi_cd_loss, num_grad_accum)
  0%|          | 0/625 [00:00<?, ?it/s]
Loss: 0.0050458284094929695
Loss: 0.0049994164146482944
Loss: 0.0038845909293740988
Loss: 0.005223155952990055
Loss: 0.004456094931811094
Loss: 0.004738576244562864
Loss: 0.004106690175831318
Loss: 0.004323168657720089
Loss: 0.005026934668421745
Loss: 0.003846581093966961
Loss: 0.004206111188977957
Loss: 0.005668825935572386
Loss: 0.006089500617235899
Loss: 0.007177875842899084
Loss: 0.004582039080560207
Loss: 0.005050532519817352
Loss: 0.006586876232177019
Loss: 0.004225502721965313
Loss: 0.005653678439557552
Loss: 0.004238472785800695
Loss: 0.005283433478325605
Loss: 0.005087869241833687
Loss: 0.004655724857002497
Loss: 0.004986477084457874
Loss: 0.004622492007911205
Loss: 0.004199301823973656
Loss: 0.004161309450864792
Loss: 0.005230826325714588
Loss: 0.004746184218674898
Loss: 0.004358099773526192
Loss: 0.005007639527320862
Loss: 0.004936085548251867
Loss: 0.005806922912597656
Loss: 0.005879228003323078
Loss: 0.004877650644630194
Loss: 0.005153404548764229
Loss: 0.004234584979712963
Loss: 0.004089971072971821
Loss: 0.004984400235116482
Loss: 0.0037555380258709192
Loss: 0.004200669005513191
Loss: 0.005815621465444565
Loss: 0.0040777213871479034
Loss: 0.0038677502889186144
Loss: 0.0045357635244727135
Loss: 0.00579131068661809
Loss: 0.004972269758582115
Loss: 0.004561882931739092
Loss: 0.005287561099976301
Loss: 0.004652189090847969
Loss: 0.0053276135586202145
Loss: 0.005421602167189121
Loss: 0.004982382990419865
Loss: 0.005762035958468914
Loss: 0.003937865141779184
Loss: 0.006004189141094685
Loss: 0.004757110029459
Loss: 0.005100958980619907
Loss: 0.005118317902088165
Loss: 0.004297970328480005
Loss: 0.005094911437481642
Loss: 0.003917417023330927
Loss: 0.004822487011551857
Loss: 0.005608697421848774
Loss: 0.00601253192871809
Loss: 0.005718870088458061
Loss: 0.00427675386890769
Loss: 0.009903686121106148
Loss: 0.0044151367619633675
Loss: 0.005351589061319828
Loss: 0.003210116410627961
Loss: 0.004703446291387081
Loss: 0.004595904611051083
Loss: 0.0066633569076657295
Loss: 0.00471059326082468
Loss: 0.004878190346062183
Loss: 0.00623040646314621
Loss: 0.005952360108494759
Loss: 0.005907438695430756
Loss: 0.006119808182120323
Loss: 0.005179271101951599
Loss: 0.0046135964803397655
Loss: 0.005898504983633757
Loss: 0.004148378036916256
Loss: 0.005078914109617472
Loss: 0.005344546400010586
Loss: 0.00580773176625371
Loss: 0.00677518080919981
Loss: 0.005773784592747688
Loss: 0.005147865507751703
Loss: 0.0061773136258125305
Loss: 0.005787029396742582
Loss: 0.005595517810434103
Loss: 0.005134418606758118
Loss: 0.007175390608608723
Loss: 0.004991271533071995
Loss: 0.0049407826736569405
Loss: 0.0038854789454489946
Loss: 0.004828581586480141
Loss: 0.006333678029477596
Loss: 0.004753076937049627
Loss: 0.00465706130489707
Loss: 0.005885921884328127
Loss: 0.004284780006855726
Loss: 0.007528802379965782
Loss: 0.006056575570255518
Loss: 0.005483128130435944
Loss: 0.006047522649168968
Loss: 0.004117307718843222
Loss: 0.005675367079675198
Loss: 0.006207931321114302
Loss: 0.004521962720900774
Loss: 0.007268839981406927
Loss: 0.005056389141827822
Loss: 0.005315299145877361
Loss: 0.006997519638389349
Loss: 0.005880238488316536
Loss: 0.005715425591915846
Loss: 0.00600104033946991
Loss: 0.004890596494078636
Loss: 0.006446502171456814
Loss: 0.007522920146584511
Loss: 0.006871495395898819
Loss: 0.0044148582965135574
Loss: 0.004788204561918974
Loss: 0.007665774319320917
Loss: 0.005442911293357611
Loss: 0.005598905961960554
Loss: 0.007184822112321854
Loss: 0.005107589531689882
Loss: 0.00593933742493391
Loss: 0.005953479558229446
Loss: 0.004294685088098049
Loss: 0.006820393726229668
Loss: 0.006238329224288464
Loss: 0.004517987370491028
Loss: 0.007577195763587952
Loss: 0.004772569052875042
Loss: 0.0034602577798068523
Loss: 0.006062277127057314
Loss: 0.005392615683376789
Loss: 0.004645330831408501
Loss: 0.00564426789060235
Loss: 0.005072068423032761
Loss: 0.0051023103296756744
Loss: 0.006630321033298969
Loss: 0.004379334393888712
Loss: 0.004586417693644762
Loss: 0.00821531843394041
Loss: 0.006193439941853285
Loss: 0.004736592527478933
Loss: 0.0057826414704322815
Loss: 0.004980041179805994
Loss: 0.0058081671595573425
Loss: 0.006939011625945568
Loss: 0.006824206560850143
Loss: 0.006509170867502689
Loss: 0.005881378427147865
Loss: 0.00484424876049161
Loss: 0.005566502455621958
Loss: 0.004801922477781773
Loss: 0.005114580038934946
Loss: 0.004530574660748243
Loss: 0.0039250487461686134
Loss: 0.005654931999742985
Loss: 0.006495376117527485
Loss: 0.006822056137025356
Loss: 0.005995016545057297
Loss: 0.00372215174138546
Loss: 0.005231975112110376
Loss: 0.005164233967661858
Loss: 0.00551933329552412
Loss: 0.004443993791937828
Loss: 0.007678712718188763
Loss: 0.004685493651777506
Loss: 0.003813052549958229
Loss: 0.006269010249525309
Loss: 0.006094068754464388
Loss: 0.005769197829067707
Loss: 0.004189694300293922
Loss: 0.004578831605613232
Loss: 0.004750285763293505
Loss: 0.005363757722079754
Loss: 0.0070570590905845165
Loss: 0.0061873942613601685
Loss: 0.005730726756155491
Loss: 0.0046767136082053185
Loss: 0.005810233298689127
Loss: 0.006317509338259697
Loss: 0.006213617045432329
Loss: 0.004917256534099579
Loss: 0.005175063852220774
Loss: 0.005359809845685959
Loss: 0.007073342800140381
Loss: 0.007034657523036003
Loss: 0.005392714403569698
Loss: 0.004346674308180809
Loss: 0.005921985022723675
Loss: 0.00522988848388195
Loss: 0.0035362415947020054
Loss: 0.0043153902515769005
Loss: 0.004544549621641636
Loss: 0.005906911566853523
Loss: 0.00751439668238163
Loss: 0.006159099284559488
Loss: 0.006597897503525019
Loss: 0.0045716483145952225
Loss: 0.0038428332190960646
Loss: 0.005491739138960838
Loss: 0.005161176435649395
Loss: 0.004417429678142071
Loss: 0.005971512757241726
Loss: 0.006094412878155708
Loss: 0.005691769532859325
Loss: 0.004417594522237778
Loss: 0.004173712804913521
Loss: 0.005995198618620634
Loss: 0.005058352369815111
Loss: 0.006446205545216799
Loss: 0.004997555632144213
Loss: 0.007061402779072523
Loss: 0.004716943018138409
Loss: 0.0038628876209259033
Loss: 0.003458902705460787
Loss: 0.00551890954375267
Loss: 0.008518392220139503
Loss: 0.004729125648736954
Loss: 0.005206206813454628
Loss: 0.006306861061602831
Loss: 0.005356411449611187
Loss: 0.003877947572618723
Loss: 0.007570785935968161
Loss: 0.005646047182381153
Loss: 0.004399370402097702
Loss: 0.00495455926284194
Loss: 0.005674805957823992
Loss: 0.004989935085177422
Loss: 0.005953742191195488
Loss: 0.00609510438516736
Loss: 0.0028848438523709774
Loss: 0.003981401212513447
Loss: 0.004475269466638565
Loss: 0.005739680491387844
Loss: 0.0056493026204407215
Loss: 0.004683582112193108
Loss: 0.004250919446349144
Loss: 0.004005677066743374
Loss: 0.0069571021012961864
Loss: 0.0034569590352475643
Loss: 0.005023623816668987
Loss: 0.007230804301798344
Loss: 0.006312592886388302
Loss: 0.007624097168445587
Loss: 0.0038831306155771017
Loss: 0.007817307487130165
Loss: 0.005868788808584213
Loss: 0.004645034205168486
Loss: 0.005006207153201103
Loss: 0.0036685147788375616
Loss: 0.006274973973631859
Loss: 0.004201329778879881
Loss: 0.005159096326678991
Loss: 0.004047416616231203
Loss: 0.004243926145136356
Loss: 0.005854310002177954
Loss: 0.004543571267277002
Loss: 0.005086420103907585
Loss: 0.003759724786505103
Loss: 0.004923189990222454
Loss: 0.004369094967842102
Loss: 0.006080180872231722
Loss: 0.004846184980124235
Loss: 0.00456547224894166
Loss: 0.005371336359530687
Loss: 0.004979487042874098
Loss: 0.00532471714541316
Loss: 0.005537898279726505
Loss: 0.005277819000184536
Loss: 0.007113277912139893
Loss: 0.005858829244971275
Loss: 0.004342781379818916
Loss: 0.00618149945512414
Loss: 0.005807905923575163
Loss: 0.0044918665662407875
Loss: 0.005821374244987965
Loss: 0.005821592640131712
Loss: 0.0067418403923511505
Loss: 0.005330349318683147
Loss: 0.0048981960862874985
Loss: 0.004887910559773445
Loss: 0.00543246092274785
Loss: 0.004298563580960035
Loss: 0.005632053129374981
Loss: 0.007678564637899399
Loss: 0.0041793216951191425
Loss: 0.0038119428791105747
Loss: 0.0054582254961133
Loss: 0.007235650904476643
Loss: 0.005144327878952026
Loss: 0.005518809892237186
Loss: 0.0047637587413191795
Loss: 0.004467202350497246
Loss: 0.006964972708374262
Loss: 0.004814565647393465
Loss: 0.0072770630940794945
Loss: 0.007659660652279854
Loss: 0.006188603118062019
Loss: 0.003359612775966525
Loss: 0.0039827944710850716
Loss: 0.0041705830954015255
Loss: 0.004182526841759682
Loss: 0.004297685343772173
Loss: 0.005174009129405022
Loss: 0.00492203701287508
Loss: 0.004399633966386318
Loss: 0.004894674755632877
Loss: 0.004836228676140308
Loss: 0.005040745250880718
Loss: 0.003990960773080587
Loss: 0.004951309412717819
Loss: 0.0037009252700954676
Loss: 0.005258262623101473
Loss: 0.005119995214045048
Loss: 0.004841494373977184
Loss: 0.0043564909137785435
Loss: 0.006630360148847103
Loss: 0.0037280465476214886
Loss: 0.004347717855125666
Loss: 0.0036667990498244762
Loss: 0.005777742713689804
Loss: 0.00385462143458426
Loss: 0.005854490213096142
Loss: 0.004427099600434303
Loss: 0.005299945827573538
Loss: 0.003393057268112898
Loss: 0.0068469601683318615
Loss: 0.0037897927686572075
Loss: 0.00389404920861125
Loss: 0.0036998852156102657
Loss: 0.005442791618406773
Loss: 0.0037937595043331385
Loss: 0.004427595064043999
Loss: 0.0033183079212903976
Loss: 0.004206876270473003
Loss: 0.004943215288221836
Loss: 0.004394850227981806
Loss: 0.005368705373257399
Loss: 0.00518504623323679
Loss: 0.005397513508796692
Loss: 0.004134576302021742
Loss: 0.003661293536424637
Loss: 0.0072310250252485275
Loss: 0.00561707466840744
Loss: 0.005165843293070793
Loss: 0.006932005286216736
Loss: 0.005707199685275555
Loss: 0.0052117216400802135
Loss: 0.00782149750739336
Loss: 0.0037163624074310064
Loss: 0.006770831532776356
Loss: 0.0052643283270299435
Loss: 0.004664583597332239
Loss: 0.0058830538764595985
Loss: 0.005807559005916119
Loss: 0.004248068667948246
Loss: 0.006133052986115217
Loss: 0.004476846195757389
Loss: 0.004394937306642532
Loss: 0.004460309166461229
Loss: 0.005661469884216785
Loss: 0.0041225990280508995
Loss: 0.004351467825472355
Loss: 0.006362120620906353
Loss: 0.0043722535483539104
Loss: 0.004927045665681362
Loss: 0.005850121378898621
Loss: 0.0037984238006174564
Loss: 0.004901641979813576
Loss: 0.005034205503761768
Loss: 0.006304995156824589
Loss: 0.005167272407561541
Loss: 0.004179590381681919
Loss: 0.005129556637257338
Loss: 0.005819693207740784
Loss: 0.006344054825603962
Loss: 0.004218774847686291
Loss: 0.004722448997199535
Loss: 0.004006600938737392
Loss: 0.00467142416164279
Loss: 0.0038078853394836187
Loss: 0.0075129433535039425
Loss: 0.005242171697318554
Loss: 0.005739223212003708
Loss: 0.0035207541659474373
Loss: 0.005362201482057571
Loss: 0.0071684797294437885
Loss: 0.005062844604253769
Loss: 0.003998470492660999
Loss: 0.004947429522871971
Loss: 0.0036834331694990396
Loss: 0.008214157074689865
Loss: 0.004800674971193075
Loss: 0.006561334244906902
Loss: 0.004546507727354765
Loss: 0.004348032642155886
Loss: 0.004217318259179592
Loss: 0.00808870978653431
Loss: 0.007666631601750851
Loss: 0.0036950942594558
Loss: 0.003824038663879037
Loss: 0.006508697755634785
Loss: 0.004360540304332972
Loss: 0.0036826906725764275
Loss: 0.00482916971668601
Loss: 0.00620920117944479
Loss: 0.005616029724478722
Loss: 0.005163357127457857
Loss: 0.0037457679864019156
Loss: 0.0038069412112236023
Loss: 0.005739779211580753
Loss: 0.004011223558336496
Loss: 0.00474869180470705
Loss: 0.004466165788471699
Loss: 0.00393492728471756
Loss: 0.003772699972614646
Loss: 0.005259126424789429
Loss: 0.005252492614090443
Loss: 0.005206356290727854
Loss: 0.004400128498673439
Loss: 0.005464805290102959
Loss: 0.004545506555587053
Loss: 0.005229127127677202
Loss: 0.005244524218142033
Loss: 0.004195576533675194
Loss: 0.004898733925074339
Loss: 0.005880438722670078
Loss: 0.005460913293063641
Loss: 0.004988692235201597
Loss: 0.004901476204395294
Loss: 0.003930442500859499
Loss: 0.0038653179071843624
Loss: 0.0047339689917862415
Loss: 0.004692728631198406
Loss: 0.005765249487012625
Loss: 0.00711049372330308
Loss: 0.004393647890537977
Loss: 0.0048315743915736675
Loss: 0.006427568383514881
Loss: 0.004131690599024296
Loss: 0.0037154380697757006
Loss: 0.002854869933798909
Loss: 0.004418167285621166
Loss: 0.0051529440097510815
Loss: 0.0053727636113762856
Loss: 0.005450004246085882
Loss: 0.005508824251592159
Loss: 0.005146097391843796
Loss: 0.005144124384969473
Loss: 0.005186060443520546
Loss: 0.004354158416390419
Loss: 0.004796842113137245
Loss: 0.0049059330485761166
Loss: 0.004350817296653986
Loss: 0.0045591555535793304
Loss: 0.00465787760913372
Loss: 0.0052376920357346535
Loss: 0.005134403705596924
Loss: 0.004403100814670324
Loss: 0.004829405806958675
Loss: 0.00581878237426281
Loss: 0.0043157367035746574
Loss: 0.004899098537862301
Loss: 0.005846755113452673
Loss: 0.006113262847065926
Loss: 0.005257738288491964
Loss: 0.003660125657916069
Loss: 0.005620865151286125
Loss: 0.004661089740693569
Loss: 0.00572964595630765
Loss: 0.0060247257351875305
Loss: 0.005429534707218409
Loss: 0.004083451349288225
Loss: 0.005173895508050919
Loss: 0.006520335096865892
Loss: 0.006244057789444923
Loss: 0.005285394378006458
Loss: 0.004174549598246813
Loss: 0.004776251036673784
Loss: 0.004901340696960688
Loss: 0.003679657122120261
Loss: 0.005355440080165863
Loss: 0.004647457040846348
Loss: 0.00586352776736021
Loss: 0.005632366985082626
Loss: 0.0035310443490743637
Loss: 0.004305234644562006
Loss: 0.003553882474079728
Loss: 0.005134738050401211
Loss: 0.004639378748834133
Loss: 0.0028205523267388344
Loss: 0.006067552603781223
Loss: 0.005349949933588505
Loss: 0.004087365232408047
Loss: 0.005343243479728699
Loss: 0.004711148329079151
Loss: 0.0053049111738801
Loss: 0.005227005574852228
Loss: 0.0068631223402917385
Loss: 0.003459248458966613
Loss: 0.0049553439021110535
Loss: 0.0060326047241687775
Loss: 0.005049269646406174
Loss: 0.004182516597211361
Loss: 0.004878777079284191
Loss: 0.005989880301058292
Loss: 0.004208489786833525
Loss: 0.006403263658285141
Loss: 0.0034297527745366096
Loss: 0.005685040727257729
Loss: 0.003985801711678505
Loss: 0.004828311502933502
Loss: 0.005921314004808664
Loss: 0.005967257544398308
Loss: 0.005216366145759821
Loss: 0.006409239489585161
Loss: 0.004347200505435467
Loss: 0.005549816880375147
Loss: 0.005856629461050034
Loss: 0.006083352491259575
Loss: 0.0039038427639752626
Loss: 0.005274396855384111
Loss: 0.0046700905077159405
Loss: 0.006810040678828955
Loss: 0.004761847667396069
Loss: 0.005670033395290375
Loss: 0.004282340873032808
Loss: 0.0070357671938836575
Loss: 0.0049394043162465096
Loss: 0.004041735082864761
Loss: 0.005748603492975235
Loss: 0.0051040975376963615
Loss: 0.005032747518271208
Loss: 0.005162765737622976
Loss: 0.006514610256999731
Loss: 0.004632913041859865
Loss: 0.004685945808887482
Loss: 0.005423092283308506
Loss: 0.005085199140012264
Loss: 0.006271375808864832
Loss: 0.005126899108290672
Loss: 0.005000355653464794
Loss: 0.004179814830422401
Loss: 0.003980487119406462
Loss: 0.0037012884858995676
Loss: 0.005241088569164276
Loss: 0.00599312037229538
Loss: 0.004957483150064945
Loss: 0.004777251742780209
Loss: 0.006968651432543993
Loss: 0.004508870653808117
Loss: 0.00495144072920084
Loss: 0.006375170312821865
Loss: 0.005465688183903694
Loss: 0.005821937695145607
Loss: 0.004899467807263136
Loss: 0.005708535201847553
Loss: 0.00423073535785079
Loss: 0.005874833557754755
Loss: 0.004133216105401516
Loss: 0.004646312445402145
Loss: 0.005897577852010727
Loss: 0.006291603669524193
Loss: 0.007561618462204933
Loss: 0.006860073190182447
Loss: 0.003710885066539049
Loss: 0.0052077993750572205
Loss: 0.006266698706895113
Loss: 0.005054624751210213
Loss: 0.0038261539302766323
Loss: 0.003448259085416794
Loss: 0.00894455797970295
Loss: 0.005979649722576141
Loss: 0.00582055002450943
Loss: 0.005652684718370438
Loss: 0.0067834677174687386
Loss: 0.0031695635989308357
Loss: 0.008927084505558014
Loss: 0.003964771516621113
Loss: 0.005587117746472359
Loss: 0.007189561612904072
Loss: 0.006516185589134693
Loss: 0.007592527661472559
Loss: 0.004328534007072449
Loss: 0.005083056632429361
Loss: 0.004025323782116175
Loss: 0.005534962750971317
Loss: 0.0044294619001448154
Loss: 0.007328921463340521
Loss: 0.006104005500674248
Loss: 0.00541450222954154
Loss: 0.0064305211417376995
Loss: 0.004474438726902008
Loss: 0.007335943169891834
Loss: 0.005208642687648535
Loss: 0.00733292056247592
Loss: 0.006287113297730684
Loss: 0.0038627712056040764
Loss: 0.006797074340283871
Loss: 0.004658001475036144
Loss: 0.006570708472281694
Loss: 0.0049662841483950615
Loss: 0.006000400520861149
Loss: 0.006568297743797302
Loss: 0.006987361237406731
Loss: 0.004486429039388895
Loss: 0.0037397085689008236
Loss: 0.006019908003509045
Loss: 0.004593628458678722
Loss: 0.0035399883054196835
Loss: 0.004847708158195019
Loss: 0.0043555619195103645
Loss: 0.005846542306244373
Loss: 0.006565350107848644
Loss: 0.006879416760057211
Loss: 0.0040451898239552975
In [ ]:
# torch.save(cm_unet.state_dict(), "mb_unet.pt")
In [ ]:
# cm_unet.load_state_dict(torch.load("mb_unet.pt"))

И в последний раз сэмплируем¶

Важно: теперь у нас появляется возможно сэмплировать детерминистично с помощью оригинального солвера DDIM за 4 шага. Так что возвращаем сэмплирование исходным pipe-ом.

Ниже прикрепляем референс и напомним, что у вас картинки могут отличаться и быть чуть хуже/лучше. img

In [31]:
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == "multi-cd"

guidance_scale = 1

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)

    images = pipe(
        prompt,
        generator=generator,
        num_inference_steps=4,
        guidance_scale=guidance_scale,
        num_images_per_prompt=4,
    ).images  # type: ignore

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Задание №8¶

Все, что осталось сделать - это загрузить ваши обученные модельки на huggingface_hub. Это очень популярный и удобный способ для хранения моделей, которые легко можно загружать и подставлять в модель. Другими словами GitHub для моделей и датасетов.

  1. Создайте аккаунт на huggingface.co

  2. Получите свой HF токен, который можно получить здесь: https://huggingface.co/settings/tokens

  3. Создайте репозиторий для ваших моделями https://huggingface.co/new

Важно: перед отправкой нотбука на проверку, не забудьте удалить свой HF токен!

In [32]:
cm_unet.push_to_hub(
    "jd-salinger/cv-week-final-task",  # "<username>/<repo-name>"
    token="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",
)
README.md:   0%|          | 0.00/31.0 [00:00<?, ?B/s]
Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]
adapter_model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]
Out[32]:
CommitInfo(commit_url='https://huggingface.co/jd-salinger/cv-week-final-task/commit/be53414d0dd081ab5d1892f3c9c46a4ecb715f6c', commit_message='Upload model', commit_description='', oid='be53414d0dd081ab5d1892f3c9c46a4ecb715f6c', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jd-salinger/cv-week-final-task', endpoint='https://huggingface.co', repo_type='model', repo_id='jd-salinger/cv-week-final-task'), pr_revision=None, pr_num=None)

Пример, как должен выглядеть результат выполнения команды: https://huggingface.co/dbaranchuk/cv-week-final-task-example

Давайте проверим, что загрузка модели корректно работает.

In [34]:
from peft import PeftModel

loaded_cm_unet = PeftModel.from_pretrained(
    unet,
    "jd-salinger/cv-week-final-task",
    token="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",
    subfolder="multi-cd",
    adapter_name="multi-cd",
)
multi-cd/adapter_config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]
In [35]:
pipe.unet = loaded_cm_unet.eval().to(torch.float16)
assert loaded_cm_unet.active_adapter == "multi-cd"

guidance_scale = 1

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)

    images = pipe(
        prompt,
        generator=generator,
        num_inference_steps=4,
        guidance_scale=guidance_scale,
        num_images_per_prompt=4,
    ).images  # type: ignore

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

На этом все! Ура!

No description has been provided for this image

P.S. Некоторые примеры плохих генераций, которые могут возникать при выполнении задания¶

Неправильный сэмплинг¶

img

img

Ошибки в обучении¶

img img

Необученная модель¶

img

In [ ]: